From f99fe281cbb6519b7b5f1199c570d496ad4df474 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 1 Apr 2024 15:38:26 -0400 Subject: [PATCH 001/356] Add LoRA+ support --- library/train_util.py | 2 ++ networks/dylora.py | 45 ++++++++++++++++++++++++++---------- networks/lora.py | 54 ++++++++++++++++++++++++++++--------------- train_network.py | 2 +- 4 files changed, 71 insertions(+), 32 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index d2b69edb..4e5ab737 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2789,6 +2789,8 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): default=1, help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", ) + parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") + parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): diff --git a/networks/dylora.py b/networks/dylora.py index 637f3345..a73ade8b 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -406,27 +406,48 @@ class DyLoRANetwork(torch.nn.Module): logger.info(f"weights are merged") """ - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): self.requires_grad_(True) all_params = [] - def enumerate_params(loras): - params = [] + def assemble_params(loras, lr, lora_plus_ratio): + param_groups = {"lora": {}, "plus": {}} for lora in loras: - params.extend(lora.parameters()) + for name, param in lora.named_parameters(): + if lora_plus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + # assigned_param_groups = "" + # for group in param_groups: + # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" + # logger.info(assigned_param_groups) + + params = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + if lr is not None: + if key == "plus": + param_data["lr"] = lr * lora_plus_ratio + else: + param_data["lr"] = lr + + if ("lr" in param_data) and (param_data["lr"] == 0): + continue + + params.append(param_data) + return params if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} - if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr - all_params.append(param_data) + params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + all_params.extend(params) if self.unet_loras: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) + params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + all_params.extend(params) return all_params diff --git a/networks/lora.py b/networks/lora.py index 948b30b0..8d761977 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1035,21 +1035,43 @@ class LoRANetwork(torch.nn.Module): return lr_weight # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): self.requires_grad_(True) all_params = [] - def enumerate_params(loras): - params = [] + def assemble_params(loras, lr, lora_plus_ratio): + param_groups = {"lora": {}, "plus": {}} for lora in loras: - params.extend(lora.parameters()) + for name, param in lora.named_parameters(): + if lora_plus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + # assigned_param_groups = "" + # for group in param_groups: + # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" + # logger.info(assigned_param_groups) + + params = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + if lr is not None: + if key == "plus": + param_data["lr"] = lr * lora_plus_ratio + else: + param_data["lr"] = lr + + if ("lr" in param_data) and (param_data["lr"] == 0): + continue + + params.append(param_data) + return params if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} - if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr - all_params.append(param_data) + params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + all_params.extend(params) if self.unet_loras: if self.block_lr: @@ -1063,21 +1085,15 @@ class LoRANetwork(torch.nn.Module): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - param_data = {"params": enumerate_params(block_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0]) + params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) elif default_lr is not None: - param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0]) - if ("lr" in param_data) and (param_data["lr"] == 0): - continue - all_params.append(param_data) + params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) + all_params.extend(params) else: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) + params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + all_params.extend(params) return all_params diff --git a/train_network.py b/train_network.py index e0fa6945..ba0c124d 100644 --- a/train_network.py +++ b/train_network.py @@ -339,7 +339,7 @@ class NetworkTrainer: # 後方互換性を確保するよ try: - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio) except TypeError: accelerator.print( "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" From c7691607ea1647864b5149c98434a27f23386c65 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 1 Apr 2024 15:43:04 -0400 Subject: [PATCH 002/356] Add LoRA-FA for LoRA+ --- networks/lora_fa.py | 58 +++++++++++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 919222ce..fcc503e8 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -1033,22 +1033,43 @@ class LoRANetwork(torch.nn.Module): return lr_weight # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, , unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): self.requires_grad_(True) all_params = [] - def enumerate_params(loras: List[LoRAModule]): - params = [] + def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio): + param_groups = {"lora": {}, "plus": {}} for lora in loras: - # params.extend(lora.parameters()) - params.extend(lora.get_trainable_params()) + for name, param in lora.get_trainable_named_params(): + if lora_plus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + # assigned_param_groups = "" + # for group in param_groups: + # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" + # logger.info(assigned_param_groups) + + params = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + if lr is not None: + if key == "plus": + param_data["lr"] = lr * lora_plus_ratio + else: + param_data["lr"] = lr + + if ("lr" in param_data) and (param_data["lr"] == 0): + continue + + params.append(param_data) + return params if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} - if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr - all_params.append(param_data) + params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + all_params.extend(params) if self.unet_loras: if self.block_lr: @@ -1062,21 +1083,15 @@ class LoRANetwork(torch.nn.Module): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - param_data = {"params": enumerate_params(block_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0]) + params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) elif default_lr is not None: - param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0]) - if ("lr" in param_data) and (param_data["lr"] == 0): - continue - all_params.append(param_data) + params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) + all_params.extend(params) else: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) + params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + all_params.extend(params) return all_params @@ -1093,6 +1108,9 @@ class LoRANetwork(torch.nn.Module): def get_trainable_params(self): return self.parameters() + def get_trainable_named_params(self): + return self.named_parameters() + def save_weights(self, file, dtype, metadata): if metadata is not None and len(metadata) == 0: metadata = None From 1933ab4b4848b1f8b578c10f25bd050f5e246ac0 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 3 Apr 2024 12:46:34 -0400 Subject: [PATCH 003/356] Fix default_lr being applied --- networks/dylora.py | 21 ++++++++++++++++++--- networks/lora.py | 30 +++++++++++++++++++++++------- networks/lora_fa.py | 30 +++++++++++++++++++++++------- 3 files changed, 64 insertions(+), 17 deletions(-) diff --git a/networks/dylora.py b/networks/dylora.py index a73ade8b..edc3e222 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -407,7 +407,14 @@ class DyLoRANetwork(torch.nn.Module): """ # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): + def prepare_optimizer_params( + self, + text_encoder_lr, + unet_lr, + default_lr, + unet_lora_plus_ratio=None, + text_encoder_lora_plus_ratio=None + ): self.requires_grad_(True) all_params = [] @@ -442,11 +449,19 @@ class DyLoRANetwork(torch.nn.Module): return params if self.text_encoder_loras: - params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + params = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + text_encoder_lora_plus_ratio + ) all_params.extend(params) if self.unet_loras: - params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + params = assemble_params( + self.unet_loras, + default_lr if unet_lr is None else unet_lr, + unet_lora_plus_ratio + ) all_params.extend(params) return all_params diff --git a/networks/lora.py b/networks/lora.py index 8d761977..e082941e 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1035,7 +1035,14 @@ class LoRANetwork(torch.nn.Module): return lr_weight # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): + def prepare_optimizer_params( + self, + text_encoder_lr, + unet_lr, + default_lr, + unet_lora_plus_ratio=None, + text_encoder_lora_plus_ratio=None + ): self.requires_grad_(True) all_params = [] @@ -1070,7 +1077,11 @@ class LoRANetwork(torch.nn.Module): return params if self.text_encoder_loras: - params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + params = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + text_encoder_lora_plus_ratio + ) all_params.extend(params) if self.unet_loras: @@ -1085,14 +1096,19 @@ class LoRANetwork(torch.nn.Module): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - if unet_lr is not None: - params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) - elif default_lr is not None: - params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) + params = assemble_params( + block_loras, + (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), + unet_lora_plus_ratio + ) all_params.extend(params) else: - params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + params = assemble_params( + self.unet_loras, + default_lr if unet_lr is None else unet_lr, + unet_lora_plus_ratio + ) all_params.extend(params) return all_params diff --git a/networks/lora_fa.py b/networks/lora_fa.py index fcc503e8..3f6774dd 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -1033,7 +1033,14 @@ class LoRANetwork(torch.nn.Module): return lr_weight # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, , unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): + def prepare_optimizer_params( + self, + text_encoder_lr, + unet_lr, + default_lr, + unet_lora_plus_ratio=None, + text_encoder_lora_plus_ratio=None + ): self.requires_grad_(True) all_params = [] @@ -1068,7 +1075,11 @@ class LoRANetwork(torch.nn.Module): return params if self.text_encoder_loras: - params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + params = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + text_encoder_lora_plus_ratio + ) all_params.extend(params) if self.unet_loras: @@ -1083,14 +1094,19 @@ class LoRANetwork(torch.nn.Module): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - if unet_lr is not None: - params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) - elif default_lr is not None: - params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) + params = assemble_params( + block_loras, + (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), + unet_lora_plus_ratio + ) all_params.extend(params) else: - params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + params = assemble_params( + self.unet_loras, + default_lr if unet_lr is None else unet_lr, + unet_lora_plus_ratio + ) all_params.extend(params) return all_params From 75833e84a1c7e3c2fb0a9e3ce0fe3d8c1758a012 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 8 Apr 2024 19:23:02 -0400 Subject: [PATCH 004/356] Fix default LR, Add overall LoRA+ ratio, Add log `--loraplus_ratio` added for both TE and UNet Add log for lora+ --- library/train_util.py | 1 + networks/dylora.py | 24 ++++++------- networks/lora.py | 28 ++++++++-------- networks/lora_fa.py | 30 ++++++++--------- train_network.py | 78 ++++++++++++++++++++++++++++++++----------- 5 files changed, 101 insertions(+), 60 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 4e5ab737..7c2bf693 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2789,6 +2789,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): default=1, help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", ) + parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") diff --git a/networks/dylora.py b/networks/dylora.py index edc3e222..dc5c7cb3 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -412,32 +412,32 @@ class DyLoRANetwork(torch.nn.Module): text_encoder_lr, unet_lr, default_lr, - unet_lora_plus_ratio=None, - text_encoder_lora_plus_ratio=None + unet_loraplus_ratio=None, + text_encoder_loraplus_ratio=None, + loraplus_ratio=None ): self.requires_grad_(True) all_params = [] - def assemble_params(loras, lr, lora_plus_ratio): + def assemble_params(loras, lr, ratio): param_groups = {"lora": {}, "plus": {}} for lora in loras: for name, param in lora.named_parameters(): - if lora_plus_ratio is not None and "lora_up" in name: + if ratio is not None and "lora_B" in name: param_groups["plus"][f"{lora.lora_name}.{name}"] = param else: param_groups["lora"][f"{lora.lora_name}.{name}"] = param - # assigned_param_groups = "" - # for group in param_groups: - # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" - # logger.info(assigned_param_groups) - params = [] for key in param_groups.keys(): param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + if lr is not None: if key == "plus": - param_data["lr"] = lr * lora_plus_ratio + param_data["lr"] = lr * ratio else: param_data["lr"] = lr @@ -452,7 +452,7 @@ class DyLoRANetwork(torch.nn.Module): params = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - text_encoder_lora_plus_ratio + text_encoder_loraplus_ratio or loraplus_ratio ) all_params.extend(params) @@ -460,7 +460,7 @@ class DyLoRANetwork(torch.nn.Module): params = assemble_params( self.unet_loras, default_lr if unet_lr is None else unet_lr, - unet_lora_plus_ratio + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) diff --git a/networks/lora.py b/networks/lora.py index e082941e..6cb05bcb 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1040,32 +1040,32 @@ class LoRANetwork(torch.nn.Module): text_encoder_lr, unet_lr, default_lr, - unet_lora_plus_ratio=None, - text_encoder_lora_plus_ratio=None + unet_loraplus_ratio=None, + text_encoder_loraplus_ratio=None, + loraplus_ratio=None ): self.requires_grad_(True) all_params = [] - def assemble_params(loras, lr, lora_plus_ratio): + def assemble_params(loras, lr, ratio): param_groups = {"lora": {}, "plus": {}} for lora in loras: for name, param in lora.named_parameters(): - if lora_plus_ratio is not None and "lora_up" in name: + if ratio is not None and "lora_up" in name: param_groups["plus"][f"{lora.lora_name}.{name}"] = param else: param_groups["lora"][f"{lora.lora_name}.{name}"] = param - # assigned_param_groups = "" - # for group in param_groups: - # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" - # logger.info(assigned_param_groups) - params = [] for key in param_groups.keys(): param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + if lr is not None: if key == "plus": - param_data["lr"] = lr * lora_plus_ratio + param_data["lr"] = lr * ratio else: param_data["lr"] = lr @@ -1080,7 +1080,7 @@ class LoRANetwork(torch.nn.Module): params = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - text_encoder_lora_plus_ratio + text_encoder_loraplus_ratio or loraplus_ratio ) all_params.extend(params) @@ -1099,15 +1099,15 @@ class LoRANetwork(torch.nn.Module): params = assemble_params( block_loras, (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), - unet_lora_plus_ratio + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) else: params = assemble_params( self.unet_loras, - default_lr if unet_lr is None else unet_lr, - unet_lora_plus_ratio + unet_lr if unet_lr is not None else default_lr, + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 3f6774dd..2eff86d6 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -1038,32 +1038,32 @@ class LoRANetwork(torch.nn.Module): text_encoder_lr, unet_lr, default_lr, - unet_lora_plus_ratio=None, - text_encoder_lora_plus_ratio=None + unet_loraplus_ratio=None, + text_encoder_loraplus_ratio=None, + loraplus_ratio=None ): self.requires_grad_(True) all_params = [] - def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio): + def assemble_params(loras, lr, ratio): param_groups = {"lora": {}, "plus": {}} for lora in loras: - for name, param in lora.get_trainable_named_params(): - if lora_plus_ratio is not None and "lora_up" in name: + for name, param in lora.named_parameters(): + if ratio is not None and "lora_up" in name: param_groups["plus"][f"{lora.lora_name}.{name}"] = param else: param_groups["lora"][f"{lora.lora_name}.{name}"] = param - # assigned_param_groups = "" - # for group in param_groups: - # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" - # logger.info(assigned_param_groups) - params = [] for key in param_groups.keys(): param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + if lr is not None: if key == "plus": - param_data["lr"] = lr * lora_plus_ratio + param_data["lr"] = lr * ratio else: param_data["lr"] = lr @@ -1078,7 +1078,7 @@ class LoRANetwork(torch.nn.Module): params = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - text_encoder_lora_plus_ratio + text_encoder_loraplus_ratio or loraplus_ratio ) all_params.extend(params) @@ -1097,15 +1097,15 @@ class LoRANetwork(torch.nn.Module): params = assemble_params( block_loras, (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), - unet_lora_plus_ratio + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) else: params = assemble_params( self.unet_loras, - default_lr if unet_lr is None else unet_lr, - unet_lora_plus_ratio + unet_lr if unet_lr is not None else default_lr, + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) diff --git a/train_network.py b/train_network.py index ba0c124d..43226fc4 100644 --- a/train_network.py +++ b/train_network.py @@ -66,14 +66,61 @@ class NetworkTrainer: lrs = lr_scheduler.get_last_lr() - if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block) - if args.network_train_unet_only: - logs["lr/unet"] = float(lrs[0]) - elif args.network_train_text_encoder_only: + if len(lrs) > 4: + idx = 0 + if not args.network_train_unet_only: logs["lr/textencoder"] = float(lrs[0]) + idx = 1 + + for i in range(idx, len(lrs)): + lora_plus = "" + group_id = i + + if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None: + lora_plus = '_lora+' if i % 2 == 1 else '' + group_id = int((i / 2) + (i % 2 + 0.5)) + + logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i]) + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): + logs[f"lr/d*lr/group{group_id}{lora_plus}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] + ) + + else: + if args.network_train_text_encoder_only: + if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/textencoder_lora+"] = float(lrs[1]) + else: + logs["lr/textencoder"] = float(lrs[0]) + + elif args.network_train_unet_only: + if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None: + logs["lr/unet"] = float(lrs[0]) + logs["lr/unet_lora+"] = float(lrs[1]) + else: + logs["lr/unet"] = float(lrs[0]) else: - logs["lr/textencoder"] = float(lrs[0]) - logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder + if len(lrs) == 2: + if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/textencoder_lora+"] = float(lrs[1]) + elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None: + logs["lr/unet"] = float(lrs[0]) + logs["lr/unet_lora+"] = float(lrs[1]) + elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None: + logs["lr/all"] = float(lrs[0]) + logs["lr/all_lora+"] = float(lrs[1]) + else: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/unet"] = float(lrs[-1]) + elif len(lrs) == 4: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/textencoder_lora+"] = float(lrs[1]) + logs["lr/unet"] = float(lrs[2]) + logs["lr/unet_lora+"] = float(lrs[3]) + else: + logs["lr/all"] = float(lrs[0]) if ( args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() @@ -81,18 +128,6 @@ class NetworkTrainer: logs["lr/d*lr"] = ( lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] ) - else: - idx = 0 - if not args.network_train_unet_only: - logs["lr/textencoder"] = float(lrs[0]) - idx = 1 - - for i in range(idx, len(lrs)): - logs[f"lr/group{i}"] = float(lrs[i]) - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): - logs[f"lr/d*lr/group{i}"] = ( - lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] - ) return logs @@ -339,7 +374,7 @@ class NetworkTrainer: # 後方互換性を確保するよ try: - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio) + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio) except TypeError: accelerator.print( "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" @@ -348,6 +383,11 @@ class NetworkTrainer: optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None: + assert ( + (optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name) + ), "LoRA+ and Prodigy/DAdaptation is not supported" + # dataloaderを準備する # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers From 68467bdf4d76ba2c57289209b0ffd6ba599e2080 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 11 Apr 2024 17:33:19 -0400 Subject: [PATCH 005/356] Fix unset or invalid LR from making a param_group --- networks/dylora.py | 4 ++-- networks/lora.py | 5 +++-- networks/lora_fa.py | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/networks/dylora.py b/networks/dylora.py index dc5c7cb3..0546fc7a 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -412,8 +412,8 @@ class DyLoRANetwork(torch.nn.Module): text_encoder_lr, unet_lr, default_lr, - unet_loraplus_ratio=None, text_encoder_loraplus_ratio=None, + unet_loraplus_ratio=None, loraplus_ratio=None ): self.requires_grad_(True) @@ -441,7 +441,7 @@ class DyLoRANetwork(torch.nn.Module): else: param_data["lr"] = lr - if ("lr" in param_data) and (param_data["lr"] == 0): + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: continue params.append(param_data) diff --git a/networks/lora.py b/networks/lora.py index 6cb05bcb..d74608fe 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1040,8 +1040,8 @@ class LoRANetwork(torch.nn.Module): text_encoder_lr, unet_lr, default_lr, - unet_loraplus_ratio=None, text_encoder_loraplus_ratio=None, + unet_loraplus_ratio=None, loraplus_ratio=None ): self.requires_grad_(True) @@ -1069,7 +1069,8 @@ class LoRANetwork(torch.nn.Module): else: param_data["lr"] = lr - if ("lr" in param_data) and (param_data["lr"] == 0): + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + print("NO LR skipping!") continue params.append(param_data) diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 2eff86d6..9a608118 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -1038,8 +1038,8 @@ class LoRANetwork(torch.nn.Module): text_encoder_lr, unet_lr, default_lr, - unet_loraplus_ratio=None, text_encoder_loraplus_ratio=None, + unet_loraplus_ratio=None, loraplus_ratio=None ): self.requires_grad_(True) @@ -1067,7 +1067,7 @@ class LoRANetwork(torch.nn.Module): else: param_data["lr"] = lr - if ("lr" in param_data) and (param_data["lr"] == 0): + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: continue params.append(param_data) From 4f203ce40d3a4647d52a2570a228e279dd04b321 Mon Sep 17 00:00:00 2001 From: 2kpr <96332338+2kpr@users.noreply.github.com> Date: Sun, 14 Apr 2024 09:56:58 -0500 Subject: [PATCH 006/356] Fused backward pass --- library/adafactor_fused.py | 106 +++++++++++++++++++++++++++++++++++++ library/train_util.py | 13 +++++ sdxl_train.py | 29 +++++++--- 3 files changed, 142 insertions(+), 6 deletions(-) create mode 100644 library/adafactor_fused.py diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py new file mode 100644 index 00000000..bdfc32ce --- /dev/null +++ b/library/adafactor_fused.py @@ -0,0 +1,106 @@ +import math +import torch +from transformers import Adafactor + +@torch.no_grad() +def adafactor_step_param(self, p, group): + if p.grad is None: + return + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = Adafactor._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + state["RMS"] = Adafactor._rms(p_data_fp32) + lr = Adafactor._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad ** 2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + + # Approximation of exponential moving average of square of gradient + update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) + + p_data_fp32.add_(-update) + + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_data_fp32) + + +@torch.no_grad() +def adafactor_step(self, closure=None): + """ + Performs a single optimization step + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + adafactor_step_param(self, p, group) + + return loss + +def patch_adafactor_fused(optimizer: Adafactor): + optimizer.step_param = adafactor_step_param.__get__(optimizer) + optimizer.step = adafactor_step.__get__(optimizer) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3c..46b55c03 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2920,6 +2920,11 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): default=1, help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", ) + parser.add_argument( + "--fused_backward_pass", + action="store_true", + help="Combines backward pass and optimizer step to reduce VRAM usage / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。", + ) def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): @@ -3846,6 +3851,14 @@ def get_optimizer(args, trainable_params): optimizer_type = "AdamW" optimizer_type = optimizer_type.lower() + if args.fused_backward_pass: + assert ( + optimizer_type == "Adafactor".lower() + ), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します" + assert ( + args.gradient_accumulation_steps == 1 + ), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません" + # 引数を分解する optimizer_kwargs = {} if args.optimizer_args is not None and len(args.optimizer_args) > 0: diff --git a/sdxl_train.py b/sdxl_train.py index 46d7860b..3b28575e 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -430,6 +430,20 @@ def train(args): text_encoder2 = accelerator.prepare(text_encoder2) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + if args.fused_backward_pass: + import library.adafactor_fused + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + # TextEncoderの出力をキャッシュするときにはCPUへ移動する if args.cache_text_encoder_outputs: # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 @@ -619,13 +633,16 @@ def train(args): loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c) accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = [] - for m in training_models: - params_to_clip.extend(m.parameters()) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - optimizer.step() + if not args.fused_backward_pass: + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() optimizer.zero_grad(set_to_none=True) From 64916a35b2378c4a8cdf3e9efeef8b8ab7ccb41c Mon Sep 17 00:00:00 2001 From: Zovjsra <4703michael@gmail.com> Date: Tue, 16 Apr 2024 16:40:08 +0800 Subject: [PATCH 007/356] add disable_mmap to args --- library/sdxl_model_util.py | 14 +++++++++----- library/sdxl_train_util.py | 9 +++++++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index f03f1bae..e6fcb1f9 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -1,4 +1,5 @@ import torch +import safetensors from accelerate import init_empty_weights from accelerate.utils.modeling import set_module_tensor_to_device from safetensors.torch import load_file, save_file @@ -163,17 +164,20 @@ def _load_state_dict_on_device(model, state_dict, device, dtype=None): raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))) -def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None): +def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False): # model_version is reserved for future use # dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching # Load the state dict if model_util.is_safetensors(ckpt_path): checkpoint = None - try: - state_dict = load_file(ckpt_path, device=map_location) - except: - state_dict = load_file(ckpt_path) # prevent device invalid Error + if(disable_mmap): + state_dict = safetensors.torch.load(open(ckpt_path, 'rb').read()) + else: + try: + state_dict = load_file(ckpt_path, device=map_location) + except: + state_dict = load_file(ckpt_path) # prevent device invalid Error epoch = None global_step = None else: diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index a29013e3..106c5b45 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -44,6 +44,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): weight_dtype, accelerator.device if args.lowram else "cpu", model_dtype, + args.disable_mmap_load_safetensors ) # work on low-ram device @@ -60,7 +61,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): def _load_target_model( - name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None + name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False ): # model_dtype only work with full fp16/bf16 name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path @@ -75,7 +76,7 @@ def _load_target_model( unet, logit_scale, ckpt_info, - ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype) + ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap) else: # Diffusers model is loaded to CPU from diffusers import StableDiffusionXLPipeline @@ -332,6 +333,10 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser): action="store_true", help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", ) + parser.add_argument( + "--disable_mmap_load_safetensors", + action="store_true", + ) def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): From feefcf256e78a5f8d60c3a940f2be3b5c3ca335d Mon Sep 17 00:00:00 2001 From: Cauldrath Date: Thu, 18 Apr 2024 23:15:36 -0400 Subject: [PATCH 008/356] Display name of error latent file When trying to load stored latents, if an error occurs, this change will tell you what file failed to load Currently it will just tell you that something failed without telling you which file --- library/train_util.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3c..58527fa0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2123,17 +2123,20 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): if not os.path.exists(npz_path): return False - npz = np.load(npz_path) - if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? - return False - if npz["latents"].shape[1:3] != expected_latents_size: - return False + try: + npz = np.load(npz_path) + if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? + return False + if npz["latents"].shape[1:3] != expected_latents_size: + return False - if flip_aug: - if "latents_flipped" not in npz: - return False - if npz["latents_flipped"].shape[1:3] != expected_latents_size: - return False + if flip_aug: + if "latents_flipped" not in npz: + return False + if npz["latents_flipped"].shape[1:3] != expected_latents_size: + return False + except: + raise RuntimeError(f"Error loading file: {npz_path}") return True From fc374375de4fc9efd10eb598fdc166a4b6d0ad17 Mon Sep 17 00:00:00 2001 From: Cauldrath Date: Thu, 18 Apr 2024 23:29:01 -0400 Subject: [PATCH 009/356] Allow negative learning rate This can be used to train away from a group of images you don't want As this moves the model away from a point instead of towards it, the change in the model is unbounded So, don't set it too low. -4e-7 seemed to work well. --- sdxl_train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index 46d7860b..1e6cec1a 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -272,7 +272,7 @@ def train(args): # 学習を準備する:モデルを適切な状態にする if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - train_unet = args.learning_rate > 0 + train_unet = args.learning_rate != 0 train_text_encoder1 = False train_text_encoder2 = False @@ -284,8 +284,8 @@ def train(args): text_encoder2.gradient_checkpointing_enable() lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train - train_text_encoder1 = lr_te1 > 0 - train_text_encoder2 = lr_te2 > 0 + train_text_encoder1 = lr_te1 != 0 + train_text_encoder2 = lr_te2 != 0 # caching one text encoder output is not supported if not train_text_encoder1: From 2c9db5d9f2f6b57f15b9312139d0410ae8ae4f3c Mon Sep 17 00:00:00 2001 From: Maatra Date: Sat, 20 Apr 2024 14:11:43 +0100 Subject: [PATCH 010/356] passing filtered hyperparameters to accelerate --- fine_tune.py | 2 +- library/train_util.py | 14 ++++++++++++++ sdxl_train.py | 2 +- sdxl_train_control_net_lllite.py | 2 +- sdxl_train_control_net_lllite_old.py | 2 +- train_controlnet.py | 2 +- train_db.py | 2 +- train_network.py | 2 +- train_textual_inversion.py | 2 +- train_textual_inversion_XTI.py | 2 +- 10 files changed, 23 insertions(+), 9 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index c7e6bbd2..77a1a4f3 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -310,7 +310,7 @@ def train(args): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs) # For --sample_at_first train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3c..40be2b05 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3378,6 +3378,20 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser): help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要", ) +def filter_sensitive_args(args: argparse.Namespace): + sensitive_args = ["wandb_api_key", "huggingface_token"] + sensitive_path_args = [ + "pretrained_model_name_or_path", + "vae", + "tokenizer_cache_dir", + "train_data_dir", + "conditioning_data_dir", + "reg_data_dir", + "output_dir", + "logging_dir", + ] + filtered_args = {k: v for k, v in vars(args).items() if k not in sensitive_args + sensitive_path_args} + return filtered_args # verify command line args for training def verify_command_line_training_args(args: argparse.Namespace): diff --git a/sdxl_train.py b/sdxl_train.py index 46d7860b..5a9aa214 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -487,7 +487,7 @@ def train(args): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs) # For --sample_at_first sdxl_train_util.sample_images( diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index f89c3628..770a1f3d 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -353,7 +353,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs ) loss_recorder = train_util.LossRecorder() diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index e85e978c..9490cf6f 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -324,7 +324,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs ) loss_recorder = train_util.LossRecorder() diff --git a/train_controlnet.py b/train_controlnet.py index f4c94e8d..793f79c7 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -344,7 +344,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs ) loss_recorder = train_util.LossRecorder() diff --git a/train_db.py b/train_db.py index 1de504ed..4f901829 100644 --- a/train_db.py +++ b/train_db.py @@ -290,7 +290,7 @@ def train(args): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs) # For --sample_at_first train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) diff --git a/train_network.py b/train_network.py index c99d3724..1dca437c 100644 --- a/train_network.py +++ b/train_network.py @@ -753,7 +753,7 @@ class NetworkTrainer: if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs ) loss_recorder = train_util.LossRecorder() diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 10fce267..56a38739 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -510,7 +510,7 @@ class TextualInversionTrainer: if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs ) # function for saving/removing diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index ddd03d53..69178523 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -407,7 +407,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs ) # function for saving/removing From 4477116a64bb6c363d0fd9fbf3e21bb813548dfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Sat, 20 Apr 2024 21:26:09 +0800 Subject: [PATCH 011/356] fix train controlnet --- library/train_util.py | 4 ++-- requirements.txt | 1 + train_controlnet.py | 8 ++++++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3c..ecf3345f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1982,8 +1982,8 @@ class ControlNetDataset(BaseDataset): self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): - return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, cache_file_suffix=".npz", divisor=8): + return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor) def __len__(self): return self.dreambooth_dataset_delegate.__len__() diff --git a/requirements.txt b/requirements.txt index e99775b8..9495dab2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,6 +17,7 @@ easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 huggingface-hub==0.20.1 +omegaconf==2.3.0 # for Image utils imagesize==1.4.1 # for BLIP captioning diff --git a/train_controlnet.py b/train_controlnet.py index f4c94e8d..763041aa 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -5,7 +5,7 @@ import os import random import time from multiprocessing import Value -from types import SimpleNamespace +from omegaconf import OmegaConf import toml from tqdm import tqdm @@ -148,8 +148,10 @@ def train(args): "in_channels": 4, "layers_per_block": 2, "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", "norm_eps": 1e-05, "norm_num_groups": 32, + "num_attention_heads": [5, 10, 20, 20], "num_class_embeds": None, "only_cross_attention": False, "out_channels": 4, @@ -179,8 +181,10 @@ def train(args): "in_channels": 4, "layers_per_block": 2, "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", "norm_eps": 1e-05, "norm_num_groups": 32, + "num_attention_heads": 8, "out_channels": 4, "sample_size": 64, "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], @@ -193,7 +197,7 @@ def train(args): "resnet_time_scale_shift": "default", "projection_class_embeddings_input_dim": None, } - unet.config = SimpleNamespace(**unet.config) + unet.config = OmegaConf.create(unet.config) controlnet = ControlNetModel.from_unet(unet) From b886d0a359526f5715f3ced05697d406a169055b Mon Sep 17 00:00:00 2001 From: Maatra Date: Sat, 20 Apr 2024 14:36:47 +0100 Subject: [PATCH 012/356] Cleaned typing to be in line with accelerate hyperparameters type resctrictions --- library/train_util.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 40be2b05..75b3420d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3390,7 +3390,20 @@ def filter_sensitive_args(args: argparse.Namespace): "output_dir", "logging_dir", ] - filtered_args = {k: v for k, v in vars(args).items() if k not in sensitive_args + sensitive_path_args} + filtered_args = {} + for k, v in vars(args).items(): + # filter out sensitive values + if k not in sensitive_args + sensitive_path_args: + #Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`. + if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int): + filtered_args[k] = v + # accelerate does not support lists + elif isinstance(v, list): + filtered_args[k] = f"{v}" + # accelerate does not support objects + elif isinstance(v, object): + filtered_args[k] = f"{v}" + return filtered_args # verify command line args for training From 5cb145d13bd9fae307a8766f4088b95f01492580 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Sat, 20 Apr 2024 21:56:24 +0800 Subject: [PATCH 013/356] Update train_util.py --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index ecf3345f..15c23f3c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1982,8 +1982,8 @@ class ControlNetDataset(BaseDataset): self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, cache_file_suffix=".npz", divisor=8): - return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor) + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) def __len__(self): return self.dreambooth_dataset_delegate.__len__() From 52652cba1a419cd72851c3882f1f877670d889c5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 21 Apr 2024 17:41:32 +0900 Subject: [PATCH 014/356] disable main process check for deepspeed #1247 --- train_network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index c99d3724..3a525516 100644 --- a/train_network.py +++ b/train_network.py @@ -474,7 +474,8 @@ class NetworkTrainer: # before resuming make hook for saving/loading to save/load the network weights only def save_model_hook(models, weights, output_dir): # pop weights of other models than network to save only network weights - if accelerator.is_main_process: + # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606 + if accelerator.is_main_process or args.deepspeed: remove_indices = [] for i, model in enumerate(models): if not isinstance(model, type(accelerator.unwrap_model(network))): From 0540c33acac223b672da05e40edcfb3b6a35c0da Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 21 Apr 2024 17:45:29 +0900 Subject: [PATCH 015/356] pop weights if available #1247 --- train_network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 3a525516..aad5a719 100644 --- a/train_network.py +++ b/train_network.py @@ -481,7 +481,8 @@ class NetworkTrainer: if not isinstance(model, type(accelerator.unwrap_model(network))): remove_indices.append(i) for i in reversed(remove_indices): - weights.pop(i) + if len(weights) > i: + weights.pop(i) # print(f"save model hook: {len(weights)} weights will be saved") def load_model_hook(models, input_dir): From 040e26ff1d8f855f52cdfb62781e06284c5e9e34 Mon Sep 17 00:00:00 2001 From: Cauldrath Date: Sun, 21 Apr 2024 13:46:31 -0400 Subject: [PATCH 016/356] Regenerate failed file If a latent file fails to load, print out the path and the error, then return false to regenerate it --- library/train_util.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 58527fa0..4168a41f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2135,8 +2135,10 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): return False if npz["latents_flipped"].shape[1:3] != expected_latents_size: return False - except: - raise RuntimeError(f"Error loading file: {npz_path}") + except Exception as e: + print(npz_path) + print(e) + return False return True From fdbb03c360777562e91ab1884ed7cf2c3d65611b Mon Sep 17 00:00:00 2001 From: frodo821 Date: Tue, 23 Apr 2024 14:29:05 +0900 Subject: [PATCH 017/356] removed unnecessary `torch` import on line 115 as per #1290 --- finetune/tag_images_by_wd14_tagger.py | 1 - 1 file changed, 1 deletion(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index a327bbd6..b3f9cdd2 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -112,7 +112,6 @@ def main(args): # モデルを読み込む if args.onnx: - import torch import onnx import onnxruntime as ort From 969f82ab474024865d292afd96768e817c9374c1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 29 Apr 2024 20:04:25 +0900 Subject: [PATCH 018/356] move loraplus args from args to network_args, simplify log lr desc --- library/train_util.py | 3 -- networks/lora.py | 58 ++++++++++++++------- train_network.py | 114 ++++++++++++++++-------------------------- 3 files changed, 84 insertions(+), 91 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 048ed2ce..15c23f3c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2920,9 +2920,6 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): default=1, help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", ) - parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") - parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") - parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): diff --git a/networks/lora.py b/networks/lora.py index edbbdc0d..b67c59bd 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -490,6 +490,14 @@ def create_network( varbose=True, ) + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) @@ -1033,18 +1041,27 @@ class LoRANetwork(torch.nn.Module): return lr_weight + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params( - self, - text_encoder_lr, - unet_lr, - default_lr, - text_encoder_loraplus_ratio=None, - unet_loraplus_ratio=None, - loraplus_ratio=None - ): + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + # TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?) + # if ( + # self.loraplus_lr_ratio is not None + # or self.loraplus_text_encoder_lr_ratio is not None + # or self.loraplus_unet_lr_ratio is not None + # ): + # assert ( + # optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower() + # ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません" + self.requires_grad_(True) + all_params = [] + lr_descriptions = [] def assemble_params(loras, lr, ratio): param_groups = {"lora": {}, "plus": {}} @@ -1056,6 +1073,7 @@ class LoRANetwork(torch.nn.Module): param_groups["lora"][f"{lora.lora_name}.{name}"] = param params = [] + descriptions = [] for key in param_groups.keys(): param_data = {"params": param_groups[key].values()} @@ -1069,20 +1087,22 @@ class LoRANetwork(torch.nn.Module): param_data["lr"] = lr if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: - print("NO LR skipping!") + logger.info("NO LR skipping!") continue params.append(param_data) + descriptions.append("plus" if key == "plus" else "") - return params + return params, descriptions if self.text_encoder_loras: - params = assemble_params( + params, descriptions = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - text_encoder_loraplus_ratio or loraplus_ratio + self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio, ) all_params.extend(params) + lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions]) if self.unet_loras: if self.block_lr: @@ -1096,22 +1116,24 @@ class LoRANetwork(torch.nn.Module): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - params = assemble_params( + params, descriptions = assemble_params( block_loras, (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), - unet_loraplus_ratio or loraplus_ratio + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, ) all_params.extend(params) + lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions]) else: - params = assemble_params( + params, descriptions = assemble_params( self.unet_loras, unet_lr if unet_lr is not None else default_lr, - unet_loraplus_ratio or loraplus_ratio + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, ) all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) - return all_params + return all_params, lr_descriptions def enable_gradient_checkpointing(self): # not supported diff --git a/train_network.py b/train_network.py index 9670490a..c43241e8 100644 --- a/train_network.py +++ b/train_network.py @@ -53,7 +53,15 @@ class NetworkTrainer: # TODO 他のスクリプトと共通化する def generate_step_logs( - self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None + self, + args: argparse.Namespace, + current_loss, + avr_loss, + lr_scheduler, + lr_descriptions, + keys_scaled=None, + mean_norm=None, + maximum_norm=None, ): logs = {"loss/current": current_loss, "loss/average": avr_loss} @@ -63,68 +71,25 @@ class NetworkTrainer: logs["max_norm/max_key_norm"] = maximum_norm lrs = lr_scheduler.get_last_lr() - - if len(lrs) > 4: - idx = 0 - if not args.network_train_unet_only: - logs["lr/textencoder"] = float(lrs[0]) - idx = 1 - - for i in range(idx, len(lrs)): - lora_plus = "" - group_id = i - - if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None: - lora_plus = '_lora+' if i % 2 == 1 else '' - group_id = int((i / 2) + (i % 2 + 0.5)) - - logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i]) - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): - logs[f"lr/d*lr/group{group_id}{lora_plus}"] = ( - lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] - ) - - else: - if args.network_train_text_encoder_only: - if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None: - logs["lr/textencoder"] = float(lrs[0]) - logs["lr/textencoder_lora+"] = float(lrs[1]) - else: - logs["lr/textencoder"] = float(lrs[0]) - - elif args.network_train_unet_only: - if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None: - logs["lr/unet"] = float(lrs[0]) - logs["lr/unet_lora+"] = float(lrs[1]) - else: - logs["lr/unet"] = float(lrs[0]) + for i, lr in enumerate(lrs): + if lr_descriptions is not None: + lr_desc = lr_descriptions[i] else: - if len(lrs) == 2: - if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None: - logs["lr/textencoder"] = float(lrs[0]) - logs["lr/textencoder_lora+"] = float(lrs[1]) - elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None: - logs["lr/unet"] = float(lrs[0]) - logs["lr/unet_lora+"] = float(lrs[1]) - elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None: - logs["lr/all"] = float(lrs[0]) - logs["lr/all_lora+"] = float(lrs[1]) - else: - logs["lr/textencoder"] = float(lrs[0]) - logs["lr/unet"] = float(lrs[-1]) - elif len(lrs) == 4: - logs["lr/textencoder"] = float(lrs[0]) - logs["lr/textencoder_lora+"] = float(lrs[1]) - logs["lr/unet"] = float(lrs[2]) - logs["lr/unet_lora+"] = float(lrs[3]) + idx = i - (0 if args.network_train_unet_only else -1) + if idx == -1: + lr_desc = "textencoder" else: - logs["lr/all"] = float(lrs[0]) + if len(lrs) > 2: + lr_desc = f"group{idx}" + else: + lr_desc = "unet" - if ( - args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() - ): # tracking d*lr value of unet. - logs["lr/d*lr"] = ( - lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + logs[f"lr/{lr_desc}"] = lr + + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): + # tracking d*lr value + logs[f"lr/d*lr/{lr_desc}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) return logs @@ -358,6 +323,7 @@ class NetworkTrainer: network.apply_to(text_encoder, unet, train_text_encoder, train_unet) if args.network_weights is not None: + # FIXME consider alpha of weights info = network.load_weights(args.network_weights) accelerator.print(f"load network weights from {args.network_weights}: {info}") @@ -373,20 +339,23 @@ class NetworkTrainer: # 後方互換性を確保するよ try: - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio) + results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + if type(results) is tuple: + trainable_params = results[0] + lr_descriptions = results[1] + else: + trainable_params = results + lr_descriptions = None except TypeError: - accelerator.print( - "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" - ) + # accelerator.print( + # "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" + # ) trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + lr_descriptions = None + print(lr_descriptions) optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) - if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None: - assert ( - (optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name) - ), "LoRA+ and Prodigy/DAdaptation is not supported" - # dataloaderを準備する # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -992,7 +961,9 @@ class NetworkTrainer: progress_bar.set_postfix(**{**max_mean_logs, **logs}) if args.logging_dir is not None: - logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) + logs = self.generate_step_logs( + args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm + ) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: @@ -1143,6 +1114,9 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") + # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") + # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") return parser From dbb7bb288e416dae56d2911077e2642ad0f4b20d Mon Sep 17 00:00:00 2001 From: Dave Lage Date: Thu, 2 May 2024 17:39:35 -0400 Subject: [PATCH 019/356] Fix caption_separator missing in subset schema --- library/config_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/config_util.py b/library/config_util.py index d75d03b0..0276acb1 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -191,6 +191,7 @@ class ConfigSanitizer: "keep_tokens": int, "keep_tokens_separator": str, "secondary_separator": str, + "caption_separator": str, "enable_wildcard": bool, "token_warmup_min": int, "token_warmup_step": Any(float, int), From 8db0cadcee47005feef5be34cbfaac8b85fe8837 Mon Sep 17 00:00:00 2001 From: Dave Lage Date: Thu, 2 May 2024 18:08:28 -0400 Subject: [PATCH 020/356] Add caption_separator to output for subset --- library/config_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/config_util.py b/library/config_util.py index d75d03b0..97554bbe 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -523,6 +523,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu shuffle_caption: {subset.shuffle_caption} keep_tokens: {subset.keep_tokens} keep_tokens_separator: {subset.keep_tokens_separator} + caption_separator: {subset.caption_separator} secondary_separator: {subset.secondary_separator} enable_wildcard: {subset.enable_wildcard} caption_dropout_rate: {subset.caption_dropout_rate} From 58c2d856ae6da6d6962cbfdd98c8a93eb790cbde Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 3 May 2024 22:18:20 +0900 Subject: [PATCH 021/356] support block dim/lr for sdxl --- networks/lora.py | 271 +++++++++++++++++++++++++++-------------------- train_network.py | 4 +- 2 files changed, 156 insertions(+), 119 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index b67c59bd..61b8cd5a 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -12,6 +12,7 @@ import numpy as np import torch import re from library.utils import setup_logging +from library.sdxl_original_unet import SdxlUNet2DConditionModel setup_logging() import logging @@ -385,14 +386,14 @@ class LoRAInfModule(LoRAModule): return out -def parse_block_lr_kwargs(nw_kwargs): +def parse_block_lr_kwargs(is_sdxl: bool, nw_kwargs: Dict) -> Optional[List[float]]: down_lr_weight = nw_kwargs.get("down_lr_weight", None) mid_lr_weight = nw_kwargs.get("mid_lr_weight", None) up_lr_weight = nw_kwargs.get("up_lr_weight", None) # 以上のいずれにも設定がない場合は無効としてNoneを返す if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None: - return None, None, None + return None # extract learning rate weight for each block if down_lr_weight is not None: @@ -401,18 +402,16 @@ def parse_block_lr_kwargs(nw_kwargs): down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] if mid_lr_weight is not None: - mid_lr_weight = float(mid_lr_weight) + mid_lr_weight = [(float(s) if s else 0.0) for s in mid_lr_weight.split(",")] if up_lr_weight is not None: if "," in up_lr_weight: up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] - down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight( - down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0)) + return get_block_lr_weight( + is_sdxl, down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0)) ) - return down_lr_weight, mid_lr_weight, up_lr_weight - def create_network( multiplier: float, @@ -424,6 +423,9 @@ def create_network( neuron_dropout: Optional[float] = None, **kwargs, ): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel) + if network_dim is None: network_dim = 4 # default if network_alpha is None: @@ -441,21 +443,21 @@ def create_network( # block dim/alpha/lr block_dims = kwargs.get("block_dims", None) - down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs) + block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs) # 以上のいずれかに指定があればblockごとのdim(rank)を有効にする - if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None: + if block_dims is not None or block_lr_weight is not None: block_alphas = kwargs.get("block_alphas", None) conv_block_dims = kwargs.get("conv_block_dims", None) conv_block_alphas = kwargs.get("conv_block_alphas", None) block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas( - block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha + is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha ) # remove block dim/alpha without learning rate block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas( - block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight + is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight ) else: @@ -488,6 +490,7 @@ def create_network( conv_block_dims=conv_block_dims, conv_block_alphas=conv_block_alphas, varbose=True, + is_sdxl=is_sdxl, ) loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) @@ -498,8 +501,8 @@ def create_network( loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) - if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: - network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) + if block_lr_weight is not None: + network.set_block_lr_weight(block_lr_weight) return network @@ -509,9 +512,13 @@ def create_network( # block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている # conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている def get_block_dims_and_alphas( - block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha + is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha ): - num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1 + if not is_sdxl: + num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS + else: + # 1+9+3+9+1=23, no LoRA for emb_layers (0) + num_total_blocks = 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1 def parse_ints(s): return [int(i) for i in s.split(",")] @@ -522,9 +529,10 @@ def get_block_dims_and_alphas( # block_dimsとblock_alphasをパースする。必ず値が入る if block_dims is not None: block_dims = parse_ints(block_dims) - assert ( - len(block_dims) == num_total_blocks - ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" + assert len(block_dims) == num_total_blocks, ( + f"block_dims must have {num_total_blocks} elements but {len(block_dims)} elements are given" + + f" / block_dimsは{num_total_blocks}個指定してください(指定された個数: {len(block_dims)})" + ) else: logger.warning( f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります" @@ -575,15 +583,25 @@ def get_block_dims_and_alphas( return block_dims, block_alphas, conv_block_dims, conv_block_alphas -# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出される可能性を考慮しておく +# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出せるようにclass外に出しておく +# 戻り値は block ごとの倍率のリスト def get_block_lr_weight( - down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold -) -> Tuple[List[float], List[float], List[float]]: + is_sdxl, + down_lr_weight: Union[str, List[float]], + mid_lr_weight: List[float], + up_lr_weight: Union[str, List[float]], + zero_threshold: float, +) -> Optional[List[float]]: # パラメータ未指定時は何もせず、今までと同じ動作とする if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None: - return None, None, None + return None - max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数 + if not is_sdxl: + max_len_for_down_or_up = LoRANetwork.NUM_OF_BLOCKS + max_len_for_mid = LoRANetwork.NUM_OF_MID_BLOCKS + else: + max_len_for_down_or_up = LoRANetwork.SDXL_NUM_OF_BLOCKS + max_len_for_mid = LoRANetwork.SDXL_NUM_OF_MID_BLOCKS def get_list(name_with_suffix) -> List[float]: import math @@ -593,15 +611,18 @@ def get_block_lr_weight( base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0 if name == "cosine": - return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))] + return [ + math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr + for i in reversed(range(max_len_for_down_or_up)) + ] elif name == "sine": - return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)] + return [math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr for i in range(max_len_for_down_or_up)] elif name == "linear": - return [i / (max_len - 1) + base_lr for i in range(max_len)] + return [i / (max_len_for_down_or_up - 1) + base_lr for i in range(max_len_for_down_or_up)] elif name == "reverse_linear": - return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))] + return [i / (max_len_for_down_or_up - 1) + base_lr for i in reversed(range(max_len_for_down_or_up))] elif name == "zeros": - return [0.0 + base_lr] * max_len + return [0.0 + base_lr] * max_len_for_down_or_up else: logger.error( "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" @@ -614,20 +635,36 @@ def get_block_lr_weight( if type(up_lr_weight) == str: up_lr_weight = get_list(up_lr_weight) - if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): - logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) - logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) - up_lr_weight = up_lr_weight[:max_len] - down_lr_weight = down_lr_weight[:max_len] + if (up_lr_weight != None and len(up_lr_weight) > max_len_for_down_or_up) or ( + down_lr_weight != None and len(down_lr_weight) > max_len_for_down_or_up + ): + logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len_for_down_or_up) + logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_down_or_up) + up_lr_weight = up_lr_weight[:max_len_for_down_or_up] + down_lr_weight = down_lr_weight[:max_len_for_down_or_up] - if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): - logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) - logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) + if mid_lr_weight != None and len(mid_lr_weight) > max_len_for_mid: + logger.warning("mid_weight is too long. Parameters after %d-th are ignored." % max_len_for_mid) + logger.warning("mid_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_mid) + mid_lr_weight = mid_lr_weight[:max_len_for_mid] - if down_lr_weight != None and len(down_lr_weight) < max_len: - down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) - if up_lr_weight != None and len(up_lr_weight) < max_len: - up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) + if (up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up) or ( + down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up + ): + logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_down_or_up) + logger.warning( + "down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_down_or_up + ) + + if down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up: + down_lr_weight = down_lr_weight + [1.0] * (max_len_for_down_or_up - len(down_lr_weight)) + if up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up: + up_lr_weight = up_lr_weight + [1.0] * (max_len_for_down_or_up - len(up_lr_weight)) + + if mid_lr_weight != None and len(mid_lr_weight) < max_len_for_mid: + logger.warning("mid_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_mid) + logger.warning("mid_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_mid) + mid_lr_weight = mid_lr_weight + [1.0] * (max_len_for_mid - len(mid_lr_weight)) if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): logger.info("apply block learning rate / 階層別学習率を適用します。") @@ -635,72 +672,84 @@ def get_block_lr_weight( down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}") else: + down_lr_weight = [1.0] * max_len_for_down_or_up logger.info("down_lr_weight: all 1.0, すべて1.0") if mid_lr_weight != None: - mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 + mid_lr_weight = [w if w > zero_threshold else 0 for w in mid_lr_weight] logger.info(f"mid_lr_weight: {mid_lr_weight}") else: - logger.info("mid_lr_weight: 1.0") + mid_lr_weight = [1.0] * max_len_for_mid + logger.info("mid_lr_weight: all 1.0, すべて1.0") if up_lr_weight != None: up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}") else: + up_lr_weight = [1.0] * max_len_for_down_or_up logger.info("up_lr_weight: all 1.0, すべて1.0") - return down_lr_weight, mid_lr_weight, up_lr_weight + lr_weight = down_lr_weight + mid_lr_weight + up_lr_weight + + if is_sdxl: + lr_weight = [1.0] + lr_weight + [1.0] # add 1.0 for emb_layers and out + + assert (not is_sdxl and len(lr_weight) == LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS) or ( + is_sdxl and len(lr_weight) == 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1 + ), f"lr_weight length is invalid: {len(lr_weight)}" + + return lr_weight # lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく def remove_block_dims_and_alphas( - block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight + is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight: Optional[List[float]] ): - # set 0 to block dim without learning rate to remove the block - if down_lr_weight != None: - for i, lr in enumerate(down_lr_weight): + if block_lr_weight is not None: + for i, lr in enumerate(block_lr_weight): if lr == 0: block_dims[i] = 0 if conv_block_dims is not None: conv_block_dims[i] = 0 - if mid_lr_weight != None: - if mid_lr_weight == 0: - block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0 - if conv_block_dims is not None: - conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0 - if up_lr_weight != None: - for i, lr in enumerate(up_lr_weight): - if lr == 0: - block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0 - if conv_block_dims is not None: - conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0 - return block_dims, block_alphas, conv_block_dims, conv_block_alphas # 外部から呼び出す可能性を考慮しておく -def get_block_index(lora_name: str) -> int: +def get_block_index(lora_name: str, is_sdxl: bool = False) -> int: block_idx = -1 # invalid lora name + if not is_sdxl: + m = RE_UPDOWN.search(lora_name) + if m: + g = m.groups() + i = int(g[1]) + j = int(g[3]) + if g[2] == "resnets": + idx = 3 * i + j + elif g[2] == "attentions": + idx = 3 * i + j + elif g[2] == "upsamplers" or g[2] == "downsamplers": + idx = 3 * i + 2 - m = RE_UPDOWN.search(lora_name) - if m: - g = m.groups() - i = int(g[1]) - j = int(g[3]) - if g[2] == "resnets": - idx = 3 * i + j - elif g[2] == "attentions": - idx = 3 * i + j - elif g[2] == "upsamplers" or g[2] == "downsamplers": - idx = 3 * i + 2 - - if g[0] == "down": - block_idx = 1 + idx # 0に該当するLoRAは存在しない - elif g[0] == "up": - block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx - - elif "mid_block_" in lora_name: - block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 + if g[0] == "down": + block_idx = 1 + idx # 0に該当するLoRAは存在しない + elif g[0] == "up": + block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx + elif "mid_block_" in lora_name: + block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 + else: + # copy from sdxl_train + if lora_name.startswith("lora_unet_"): + name = lora_name[len("lora_unet_") :] + if name.startswith("time_embed_") or name.startswith("label_emb_"): # No LoRA + block_idx = 0 # 0 + elif name.startswith("input_blocks_"): # 1-9 + block_idx = 1 + int(name.split("_")[2]) + elif name.startswith("middle_block_"): # 10-12 + block_idx = 10 + int(name.split("_")[2]) + elif name.startswith("output_blocks_"): # 13-21 + block_idx = 13 + int(name.split("_")[2]) + elif name.startswith("out_"): # 22, out, no LoRA + block_idx = 22 return block_idx @@ -742,15 +791,18 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh ) # block lr - down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs) - if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: - network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) + block_lr_weight = parse_block_lr_kwargs(kwargs) + if block_lr_weight is not None: + network.set_block_lr_weight(block_lr_weight) return network, weights_sd class LoRANetwork(torch.nn.Module): NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 + NUM_OF_MID_BLOCKS = 1 + SDXL_NUM_OF_BLOCKS = 9 # SDXLのモデルでのinput/outputの層の数 total=1(base) 9(input) + 3(mid) + 9(output) + 1(out) = 23 + SDXL_NUM_OF_MID_BLOCKS = 3 UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] @@ -782,6 +834,7 @@ class LoRANetwork(torch.nn.Module): modules_alpha: Optional[Dict[str, int]] = None, module_class: Type[object] = LoRAModule, varbose: Optional[bool] = False, + is_sdxl: Optional[bool] = False, ) -> None: """ LoRA network: すごく引数が多いが、パターンは以下の通り @@ -863,7 +916,7 @@ class LoRANetwork(torch.nn.Module): alpha = modules_alpha[lora_name] elif is_unet and block_dims is not None: # U-Netでblock_dims指定あり - block_idx = get_block_index(lora_name) + block_idx = get_block_index(lora_name, is_sdxl) if is_linear or is_conv2d_1x1: dim = block_dims[block_idx] alpha = block_alphas[block_idx] @@ -927,15 +980,13 @@ class LoRANetwork(torch.nn.Module): skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: - logger.warning( + logger.warn( f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) for name in skipped: logger.info(f"\t{name}") - self.up_lr_weight: List[float] = None - self.down_lr_weight: List[float] = None - self.mid_lr_weight: float = None + self.block_lr_weight = None self.block_lr = False # assertion @@ -966,12 +1017,12 @@ class LoRANetwork(torch.nn.Module): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - logger.info("enable LoRA for text encoder") + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") else: self.text_encoder_loras = [] if apply_unet: - logger.info("enable LoRA for U-Net") + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") else: self.unet_loras = [] @@ -1012,34 +1063,14 @@ class LoRANetwork(torch.nn.Module): logger.info(f"weights are merged") # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない - def set_block_lr_weight( - self, - up_lr_weight: List[float] = None, - mid_lr_weight: float = None, - down_lr_weight: List[float] = None, - ): + def set_block_lr_weight(self, block_lr_weight: Optional[List[float]]): self.block_lr = True - self.down_lr_weight = down_lr_weight - self.mid_lr_weight = mid_lr_weight - self.up_lr_weight = up_lr_weight + self.block_lr_weight = block_lr_weight - def get_lr_weight(self, lora: LoRAModule) -> float: - lr_weight = 1.0 - block_idx = get_block_index(lora.lora_name) - if block_idx < 0: - return lr_weight - - if block_idx < LoRANetwork.NUM_OF_BLOCKS: - if self.down_lr_weight != None: - lr_weight = self.down_lr_weight[block_idx] - elif block_idx == LoRANetwork.NUM_OF_BLOCKS: - if self.mid_lr_weight != None: - lr_weight = self.mid_lr_weight - elif block_idx > LoRANetwork.NUM_OF_BLOCKS: - if self.up_lr_weight != None: - lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1] - - return lr_weight + def get_lr_weight(self, block_idx: int) -> float: + if not self.block_lr or self.block_lr_weight is None: + return 1.0 + return self.block_lr_weight[block_idx] def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): self.loraplus_lr_ratio = loraplus_lr_ratio @@ -1106,10 +1137,16 @@ class LoRANetwork(torch.nn.Module): if self.unet_loras: if self.block_lr: + is_sdxl = False + for lora in self.unet_loras: + if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name: + is_sdxl = True + break + # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 block_idx_to_lora = {} for lora in self.unet_loras: - idx = get_block_index(lora.lora_name) + idx = get_block_index(lora.lora_name, is_sdxl) if idx not in block_idx_to_lora: block_idx_to_lora[idx] = [] block_idx_to_lora[idx].append(lora) @@ -1118,7 +1155,7 @@ class LoRANetwork(torch.nn.Module): for idx, block_loras in block_idx_to_lora.items(): params, descriptions = assemble_params( block_loras, - (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), + (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx), self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, ) all_params.extend(params) diff --git a/train_network.py b/train_network.py index c43241e8..2976f763 100644 --- a/train_network.py +++ b/train_network.py @@ -346,13 +346,13 @@ class NetworkTrainer: else: trainable_params = results lr_descriptions = None - except TypeError: + except TypeError as e: + # logger.warning(f"{e}") # accelerator.print( # "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" # ) trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) lr_descriptions = None - print(lr_descriptions) optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) From 52e64c69cf249a7bc4ca6f4eebe82bc1b70e617b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 4 May 2024 18:43:52 +0900 Subject: [PATCH 022/356] add debug log --- train_network.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/train_network.py b/train_network.py index 2976f763..feb455ce 100644 --- a/train_network.py +++ b/train_network.py @@ -354,6 +354,16 @@ class NetworkTrainer: trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) lr_descriptions = None + # if len(trainable_params) == 0: + # accelerator.print("no trainable parameters found / 学習可能なパラメータが見つかりませんでした") + # for params in trainable_params: + # for k, v in params.items(): + # if type(v) == float: + # pass + # else: + # v = len(v) + # accelerator.print(f"trainable_params: {k} = {v}") + optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する From 7fe81502d04c1f68c85f276517e7144e6378c484 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 6 May 2024 11:09:32 +0900 Subject: [PATCH 023/356] update loraplus on dylora/lofa_fa --- networks/dylora.py | 46 ++++++++++++++++++++++++--------------- networks/lora.py | 7 +++++- networks/lora_fa.py | 52 +++++++++++++++++++++++++++++++-------------- 3 files changed, 71 insertions(+), 34 deletions(-) diff --git a/networks/dylora.py b/networks/dylora.py index 0546fc7a..0d1701de 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -18,10 +18,13 @@ from transformers import CLIPTextModel import torch from torch import nn from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + class DyLoRAModule(torch.nn.Module): """ replaces forward method of the original Linear, instead of replacing the original Linear module. @@ -195,7 +198,7 @@ def create_network( conv_alpha = 1.0 else: conv_alpha = float(conv_alpha) - + if unit is not None: unit = int(unit) else: @@ -211,6 +214,16 @@ def create_network( unit=unit, varbose=True, ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + return network @@ -280,6 +293,10 @@ class DyLoRANetwork(torch.nn.Module): self.alpha = alpha self.apply_to_conv = apply_to_conv + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + if modules_dim is not None: logger.info("create LoRA network from weights") else: @@ -320,9 +337,9 @@ class DyLoRANetwork(torch.nn.Module): lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit) loras.append(lora) return loras - + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] - + self.text_encoder_loras = [] for i, text_encoder in enumerate(text_encoders): if len(text_encoders) > 1: @@ -331,7 +348,7 @@ class DyLoRANetwork(torch.nn.Module): else: index = None logger.info("create LoRA for Text Encoder") - + text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) @@ -346,6 +363,11 @@ class DyLoRANetwork(torch.nn.Module): self.unet_loras = create_modules(True, unet, target_modules) logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + def set_multiplier(self, multiplier): self.multiplier = multiplier for lora in self.text_encoder_loras + self.unet_loras: @@ -407,15 +429,7 @@ class DyLoRANetwork(torch.nn.Module): """ # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params( - self, - text_encoder_lr, - unet_lr, - default_lr, - text_encoder_loraplus_ratio=None, - unet_loraplus_ratio=None, - loraplus_ratio=None - ): + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): self.requires_grad_(True) all_params = [] @@ -452,15 +466,13 @@ class DyLoRANetwork(torch.nn.Module): params = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - text_encoder_loraplus_ratio or loraplus_ratio + self.loraplus_text_encoder_lr_ratio or self.loraplus_ratio, ) all_params.extend(params) if self.unet_loras: params = assemble_params( - self.unet_loras, - default_lr if unet_lr is None else unet_lr, - unet_loraplus_ratio or loraplus_ratio + self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_ratio ) all_params.extend(params) diff --git a/networks/lora.py b/networks/lora.py index 61b8cd5a..6e564557 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -499,7 +499,8 @@ def create_network( loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None - network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) if block_lr_weight is not None: network.set_block_lr_weight(block_lr_weight) @@ -855,6 +856,10 @@ class LoRANetwork(torch.nn.Module): self.rank_dropout = rank_dropout self.module_dropout = module_dropout + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + if modules_dim is not None: logger.info(f"create LoRA network from weights") elif block_dims is not None: diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 9a608118..58bcb220 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -15,8 +15,10 @@ import numpy as np import torch import re from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -504,6 +506,15 @@ def create_network( if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + return network @@ -529,7 +540,9 @@ def get_block_dims_and_alphas( len(block_dims) == num_total_blocks ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" else: - logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") + logger.warning( + f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります" + ) block_dims = [network_dim] * num_total_blocks if block_alphas is not None: @@ -803,11 +816,17 @@ class LoRANetwork(torch.nn.Module): self.rank_dropout = rank_dropout self.module_dropout = module_dropout + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + if modules_dim is not None: logger.info(f"create LoRA network from weights") elif block_dims is not None: logger.info(f"create LoRA network from block_dims") - logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) logger.info(f"block_dims: {block_dims}") logger.info(f"block_alphas: {block_alphas}") if conv_block_dims is not None: @@ -815,9 +834,13 @@ class LoRANetwork(torch.nn.Module): logger.info(f"conv_block_alphas: {conv_block_alphas}") else: logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) if self.conv_lora_dim is not None: - logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + logger.info( + f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + ) # create module instances def create_modules( @@ -939,6 +962,11 @@ class LoRANetwork(torch.nn.Module): assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" names.add(lora.lora_name) + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + def set_multiplier(self, multiplier): self.multiplier = multiplier for lora in self.text_encoder_loras + self.unet_loras: @@ -1033,15 +1061,7 @@ class LoRANetwork(torch.nn.Module): return lr_weight # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params( - self, - text_encoder_lr, - unet_lr, - default_lr, - text_encoder_loraplus_ratio=None, - unet_loraplus_ratio=None, - loraplus_ratio=None - ): + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): self.requires_grad_(True) all_params = [] @@ -1078,7 +1098,7 @@ class LoRANetwork(torch.nn.Module): params = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - text_encoder_loraplus_ratio or loraplus_ratio + self.loraplus_text_encoder_lr_ratio or self.loraplus_ratio, ) all_params.extend(params) @@ -1097,7 +1117,7 @@ class LoRANetwork(torch.nn.Module): params = assemble_params( block_loras, (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), - unet_loraplus_ratio or loraplus_ratio + self.loraplus_unet_lr_ratio or self.loraplus_ratio, ) all_params.extend(params) @@ -1105,7 +1125,7 @@ class LoRANetwork(torch.nn.Module): params = assemble_params( self.unet_loras, unet_lr if unet_lr is not None else default_lr, - unet_loraplus_ratio or loraplus_ratio + self.loraplus_unet_lr_ratio or self.loraplus_ratio, ) all_params.extend(params) From 3fd8cdc55d7d87ceca2dc1127a807a7ddafb15ae Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 6 May 2024 14:03:19 +0900 Subject: [PATCH 024/356] fix dylora loraplus --- networks/dylora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/networks/dylora.py b/networks/dylora.py index 0d1701de..d57e3d58 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -466,13 +466,13 @@ class DyLoRANetwork(torch.nn.Module): params = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - self.loraplus_text_encoder_lr_ratio or self.loraplus_ratio, + self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio, ) all_params.extend(params) if self.unet_loras: params = assemble_params( - self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_ratio + self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio ) all_params.extend(params) From 017b82ebe33a2199c8f842c99905f59c54292f56 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 6 May 2024 15:05:42 +0900 Subject: [PATCH 025/356] update help message for fused_backward_pass --- library/train_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 46b55c03..e3c0229a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2923,7 +2923,8 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--fused_backward_pass", action="store_true", - help="Combines backward pass and optimizer step to reduce VRAM usage / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。", + help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL" + + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効", ) From b56d5f7801dea45cdbbba8498544e8d2853ad6d6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 6 May 2024 21:35:39 +0900 Subject: [PATCH 026/356] add experimental option to fuse params to optimizer groups --- sdxl_train.py | 114 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 104 insertions(+), 10 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index 3b28575e..c7eea222 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -345,8 +345,8 @@ def train(args): # calculate number of trainable parameters n_params = 0 - for params in params_to_optimize: - for p in params["params"]: + for group in params_to_optimize: + for p in group["params"]: n_params += p.numel() accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}") @@ -355,7 +355,44 @@ def train(args): # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + if args.fused_optimizer_groups: + # calculate total number of parameters + n_total_params = sum(len(params["params"]) for params in params_to_optimize) + params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) + + # split params into groups + grouped_params = [] + param_group = [] + param_group_lr = -1 + for group in params_to_optimize: + lr = group["lr"] + for p in group["params"]: + if lr != param_group_lr: + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = lr + param_group.append(p) + if len(param_group) == params_per_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = -1 + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + print(len(grouped_params)) + logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) # dataloaderを準備する # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 @@ -382,7 +419,11 @@ def train(args): train_dataset_group.set_max_train_steps(args.max_train_steps) # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + if args.fused_optimizer_groups: + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする if args.full_fp16: @@ -432,10 +473,12 @@ def train(args): if args.fused_backward_pass: import library.adafactor_fused + library.adafactor_fused.patch_adafactor_fused(optimizer) for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: + def __grad_hook(tensor: torch.Tensor, param_group=param_group): if accelerator.sync_gradients and args.max_grad_norm != 0.0: accelerator.clip_grad_norm_(tensor, args.max_grad_norm) @@ -444,6 +487,36 @@ def train(args): parameter.register_post_accumulate_grad_hook(__grad_hook) + elif args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def optimizer_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad() + + parameter.register_post_accumulate_grad_hook(optimizer_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + # TextEncoderの出力をキャッシュするときにはCPUへ移動する if args.cache_text_encoder_outputs: # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 @@ -518,6 +591,10 @@ def train(args): for step, batch in enumerate(train_dataloader): current_step.value = global_step + + if args.fused_optimizer_groups: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} + with accelerator.accumulate(*training_models): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) @@ -596,7 +673,9 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -614,7 +693,9 @@ def train(args): or args.masked_loss ): # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) if args.masked_loss: loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -630,11 +711,13 @@ def train(args): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + ) accelerator.backward(loss) - if not args.fused_backward_pass: + if not (args.fused_backward_pass or args.fused_optimizer_groups): if accelerator.sync_gradients and args.max_grad_norm != 0.0: params_to_clip = [] for m in training_models: @@ -642,9 +725,14 @@ def train(args): accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() + elif args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + + if not (args.fused_backward_pass or args.fused_optimizer_groups): + optimizer.zero_grad(set_to_none=True) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: @@ -753,7 +841,7 @@ def train(args): accelerator.end_training() - if args.save_state or args.save_state_on_train_end: + if args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す @@ -822,6 +910,12 @@ def setup_parser() -> argparse.ArgumentParser: help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", ) + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + ) return parser From 793aeb94da53565fb08c7b0b2538f2ade04824bb Mon Sep 17 00:00:00 2001 From: AngelBottomless Date: Tue, 7 May 2024 18:21:31 +0900 Subject: [PATCH 027/356] fix get_trainable_params in controlnet-llite training --- sdxl_train_control_net_lllite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index f89c3628..6ad6e763 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -477,7 +477,7 @@ def train(args): accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = unet.get_trainable_params() + params_to_clip = accelerator.unwrap_model(unet).get_trainable_params() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() From 607e041f3de972f2c3030e7c8b43dfc3c2eb2d65 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 14:16:41 +0900 Subject: [PATCH 028/356] chore: Refactor optimizer group --- sdxl_train.py | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index c7eea222..be2b7166 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -357,27 +357,37 @@ def train(args): accelerator.print("prepare optimizer, data loader etc.") if args.fused_optimizer_groups: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # This balances memory usage and management complexity. + # calculate total number of parameters n_total_params = sum(len(params["params"]) for params in params_to_optimize) params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) - # split params into groups + # split params into groups, keeping the learning rate the same for all params in a group + # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) grouped_params = [] param_group = [] param_group_lr = -1 for group in params_to_optimize: lr = group["lr"] for p in group["params"]: + # if the learning rate is different for different params, start a new group if lr != param_group_lr: if param_group: grouped_params.append({"params": param_group, "lr": param_group_lr}) param_group = [] param_group_lr = lr + param_group.append(p) + + # if the group has enough parameters, start a new group if len(param_group) == params_per_group: grouped_params.append({"params": param_group, "lr": param_group_lr}) param_group = [] param_group_lr = -1 + if param_group: grouped_params.append({"params": param_group, "lr": param_group_lr}) @@ -388,7 +398,6 @@ def train(args): optimizers.append(optimizer) optimizer = optimizers[0] # avoid error in the following code - print(len(grouped_params)) logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") else: @@ -420,6 +429,7 @@ def train(args): # lr schedulerを用意する if args.fused_optimizer_groups: + # prepare lr schedulers for each optimizer lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] lr_scheduler = lr_schedulers[0] # avoid error in the following code else: @@ -472,6 +482,7 @@ def train(args): optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) @@ -488,16 +499,20 @@ def train(args): parameter.register_post_accumulate_grad_hook(__grad_hook) elif args.fused_optimizer_groups: + # prepare for additional optimizers and lr schedulers for i in range(1, len(optimizers)): optimizers[i] = accelerator.prepare(optimizers[i]) lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + # counters are used to determine when to step the optimizer global optimizer_hooked_count global num_parameters_per_group global parameter_optimizer_map + optimizer_hooked_count = {} num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} + for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: for parameter in param_group["params"]: @@ -511,7 +526,7 @@ def train(args): optimizer_hooked_count[i] += 1 if optimizer_hooked_count[i] == num_parameters_per_group[i]: optimizers[i].step() - optimizers[i].zero_grad() + optimizers[i].zero_grad(set_to_none=True) parameter.register_post_accumulate_grad_hook(optimizer_hook) parameter_optimizer_map[parameter] = opt_idx @@ -593,7 +608,7 @@ def train(args): current_step.value = global_step if args.fused_optimizer_groups: - optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step with accelerator.accumulate(*training_models): if "latents" in batch and batch["latents"] is not None: @@ -725,14 +740,14 @@ def train(args): accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() - elif args.fused_optimizer_groups: - for i in range(1, len(optimizers)): - lr_schedulers[i].step() - - lr_scheduler.step() - - if not (args.fused_backward_pass or args.fused_optimizer_groups): + lr_scheduler.step() optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: From c1ba0b4356637c881ea99663fcce5943fc33fc56 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 14:21:10 +0900 Subject: [PATCH 029/356] update readme --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index a7047a36..859a7618 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,14 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### Working in progress + +- Fixed some bugs when using DeepSpeed. Related [#1247] + + +- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247] + + ### Apr 7, 2024 / 2024-04-07: v0.8.7 - The default value of `huber_schedule` in Scheduled Huber Loss is changed from `exponential` to `snr`, which is expected to give better results. From f3d2cf22ff9ad49e7f8bd68494714fa3bedbd77d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 15:03:02 +0900 Subject: [PATCH 030/356] update README for fused optimizer --- README.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/README.md b/README.md index 859a7618..4fd97fb2 100644 --- a/README.md +++ b/README.md @@ -139,8 +139,37 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress +- Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! + - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. + - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available. + - Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`. + - Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size. + - PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`. + - Mechanism: Normally, backward -> step is performed for each parameter, so all gradients need to be temporarily stored in memory. "Fuse backward and step" reduces memory usage by performing backward/step for each parameter and reflecting the gradient immediately. + +- Optimizer groups feature is added to SDXL training. PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319) + - Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer. + - Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect is limited to a certain number, it is recommended to specify 4-10. + - Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available. + - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required. + - Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side. + - Fixed some bugs when using DeepSpeed. Related [#1247] +- SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 + - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。 + - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。 + - mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。 + - バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。 + - PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。 + - 仕組み:通常は backward -> step の順で行うためすべての勾配を一時的にメモリに保持する必要があります。「backward と step の統合」はパラメータごとに backward/step を行って、勾配をすぐ反映することでメモリ使用量を削減します。 + +- SDXL の学習時に optimizer group 機能を追加しました。PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319) + - Fused optimizer と同様の原理でメモリ使用量を削減します。学習結果や速度についても同様です。 + - `sdxl_train.py` に `--fused_optimizer_groups 10` のようにグループ数を指定してください。グループ数を増やすとメモリ使用量が削減されますが、速度は遅くなります。ある程度の数までしか効果がないため、4~10 程度を指定すると良いでしょう。 + - 任意の optimizer が使えますが、学習率を自動計算する optimizer (D-Adaptation や Prodigy など)は使えません。gradient accumulation は使えません。 + - `--fused_optimizer_groups` は `--fused_backward_pass` と併用できません。AdaFactor 使用時は Fused optimizer よりも若干メモリ使用量は大きくなります。PyTorch 2.1 以降が必要です。 + - 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。 - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247] From bee8cee7e8fbeecc05b1c80a1e9e8fadab3210a5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 15:08:52 +0900 Subject: [PATCH 031/356] update README for fused optimizer --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 4fd97fb2..9c7ecad9 100644 --- a/README.md +++ b/README.md @@ -145,7 +145,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`. - Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size. - PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`. - - Mechanism: Normally, backward -> step is performed for each parameter, so all gradients need to be temporarily stored in memory. "Fuse backward and step" reduces memory usage by performing backward/step for each parameter and reflecting the gradient immediately. + - Mechanism: Normally, backward -> step is performed for each parameter, so all gradients need to be temporarily stored in memory. "Fuse backward and step" reduces memory usage by performing backward/step for each parameter and reflecting the gradient immediately. The more parameters there are, the greater the effect, so it is not effective in other training scripts (LoRA, etc.) where the memory usage peak is elsewhere, and there are no plans to implement it in those training scripts. - Optimizer groups feature is added to SDXL training. PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319) - Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer. @@ -162,14 +162,14 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。 - バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。 - PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。 - - 仕組み:通常は backward -> step の順で行うためすべての勾配を一時的にメモリに保持する必要があります。「backward と step の統合」はパラメータごとに backward/step を行って、勾配をすぐ反映することでメモリ使用量を削減します。 + - 仕組み:通常は backward -> step の順で行うためすべての勾配を一時的にメモリに保持する必要があります。「backward と step の統合」はパラメータごとに backward/step を行って、勾配をすぐ反映することでメモリ使用量を削減します。パラメータ数が多いほど効果が大きいため、SDXL の学習以外(LoRA 等)ではほぼ効果がなく(メモリ使用量のピークが他の場所にあるため)、それらの学習スクリプトへの実装予定もありません。 - SDXL の学習時に optimizer group 機能を追加しました。PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319) - Fused optimizer と同様の原理でメモリ使用量を削減します。学習結果や速度についても同様です。 - `sdxl_train.py` に `--fused_optimizer_groups 10` のようにグループ数を指定してください。グループ数を増やすとメモリ使用量が削減されますが、速度は遅くなります。ある程度の数までしか効果がないため、4~10 程度を指定すると良いでしょう。 - 任意の optimizer が使えますが、学習率を自動計算する optimizer (D-Adaptation や Prodigy など)は使えません。gradient accumulation は使えません。 - `--fused_optimizer_groups` は `--fused_backward_pass` と併用できません。AdaFactor 使用時は Fused optimizer よりも若干メモリ使用量は大きくなります。PyTorch 2.1 以降が必要です。 - - 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。 + - 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。やはり SDXL の学習でのみ効果があります。 - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247] From 1ffc0b330aa362a408e46e9a52784d72aa73d263 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 16:18:43 +0900 Subject: [PATCH 032/356] fix typo --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index e3c0229a..b2de8a21 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3093,7 +3093,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: ) parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument( - "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする" + "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする" ) parser.add_argument( "--gradient_accumulation_steps", From 3c8193f64269fff68d16c1f38dedfde8715f70bb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 17:00:51 +0900 Subject: [PATCH 033/356] revert lora+ for lora_fa --- networks/lora_fa.py | 104 +++++++++++--------------------------------- 1 file changed, 25 insertions(+), 79 deletions(-) diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 58bcb220..919222ce 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -15,10 +15,8 @@ import numpy as np import torch import re from library.utils import setup_logging - setup_logging() import logging - logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -506,15 +504,6 @@ def create_network( if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) - loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) - loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) - loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) - loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None - loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None - loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None - if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: - network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) - return network @@ -540,9 +529,7 @@ def get_block_dims_and_alphas( len(block_dims) == num_total_blocks ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" else: - logger.warning( - f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります" - ) + logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") block_dims = [network_dim] * num_total_blocks if block_alphas is not None: @@ -816,17 +803,11 @@ class LoRANetwork(torch.nn.Module): self.rank_dropout = rank_dropout self.module_dropout = module_dropout - self.loraplus_lr_ratio = None - self.loraplus_unet_lr_ratio = None - self.loraplus_text_encoder_lr_ratio = None - if modules_dim is not None: logger.info(f"create LoRA network from weights") elif block_dims is not None: logger.info(f"create LoRA network from block_dims") - logger.info( - f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" - ) + logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") logger.info(f"block_dims: {block_dims}") logger.info(f"block_alphas: {block_alphas}") if conv_block_dims is not None: @@ -834,13 +815,9 @@ class LoRANetwork(torch.nn.Module): logger.info(f"conv_block_alphas: {conv_block_alphas}") else: logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - logger.info( - f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" - ) + logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") if self.conv_lora_dim is not None: - logger.info( - f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" - ) + logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") # create module instances def create_modules( @@ -962,11 +939,6 @@ class LoRANetwork(torch.nn.Module): assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" names.add(lora.lora_name) - def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): - self.loraplus_lr_ratio = loraplus_lr_ratio - self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio - self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio - def set_multiplier(self, multiplier): self.multiplier = multiplier for lora in self.text_encoder_loras + self.unet_loras: @@ -1065,42 +1037,18 @@ class LoRANetwork(torch.nn.Module): self.requires_grad_(True) all_params = [] - def assemble_params(loras, lr, ratio): - param_groups = {"lora": {}, "plus": {}} - for lora in loras: - for name, param in lora.named_parameters(): - if ratio is not None and "lora_up" in name: - param_groups["plus"][f"{lora.lora_name}.{name}"] = param - else: - param_groups["lora"][f"{lora.lora_name}.{name}"] = param - + def enumerate_params(loras: List[LoRAModule]): params = [] - for key in param_groups.keys(): - param_data = {"params": param_groups[key].values()} - - if len(param_data["params"]) == 0: - continue - - if lr is not None: - if key == "plus": - param_data["lr"] = lr * ratio - else: - param_data["lr"] = lr - - if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: - continue - - params.append(param_data) - + for lora in loras: + # params.extend(lora.parameters()) + params.extend(lora.get_trainable_params()) return params if self.text_encoder_loras: - params = assemble_params( - self.text_encoder_loras, - text_encoder_lr if text_encoder_lr is not None else default_lr, - self.loraplus_text_encoder_lr_ratio or self.loraplus_ratio, - ) - all_params.extend(params) + param_data = {"params": enumerate_params(self.text_encoder_loras)} + if text_encoder_lr is not None: + param_data["lr"] = text_encoder_lr + all_params.append(param_data) if self.unet_loras: if self.block_lr: @@ -1114,20 +1062,21 @@ class LoRANetwork(torch.nn.Module): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - params = assemble_params( - block_loras, - (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), - self.loraplus_unet_lr_ratio or self.loraplus_ratio, - ) - all_params.extend(params) + param_data = {"params": enumerate_params(block_loras)} + + if unet_lr is not None: + param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0]) + elif default_lr is not None: + param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0]) + if ("lr" in param_data) and (param_data["lr"] == 0): + continue + all_params.append(param_data) else: - params = assemble_params( - self.unet_loras, - unet_lr if unet_lr is not None else default_lr, - self.loraplus_unet_lr_ratio or self.loraplus_ratio, - ) - all_params.extend(params) + param_data = {"params": enumerate_params(self.unet_loras)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) return all_params @@ -1144,9 +1093,6 @@ class LoRANetwork(torch.nn.Module): def get_trainable_params(self): return self.parameters() - def get_trainable_named_params(self): - return self.named_parameters() - def save_weights(self, file, dtype, metadata): if metadata is not None and len(metadata) == 0: metadata = None From 44190416c6389d9ae9ffb18c28744be1259fc02c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 17:01:20 +0900 Subject: [PATCH 034/356] update docs etc. --- README.md | 26 ++++++++++++++++++++++++-- docs/train_network_README-ja.md | 11 +++++++---- networks/lora.py | 2 +- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 9c7ecad9..b10da0f2 100644 --- a/README.md +++ b/README.md @@ -154,7 +154,18 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required. - Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side. -- Fixed some bugs when using DeepSpeed. Related [#1247] +- LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO! + - LoRA+ is a method to improve training speed by increasing the learning rate of the UP side (LoRA-B) of LoRA. Specify the multiple. The original paper recommends 16, but adjust as needed. Please see the PR for details. + - Specify `loraplus_lr_ratio` with `--network_args`. Example: `--network_args "loraplus_lr_ratio=16"` + - `loraplus_unet_lr_ratio` and `loraplus_lr_ratio` can be specified separately for U-Net and Text Encoder. + - Example: `--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` or `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` etc. + - `network_module` `networks.lora` and `networks.dylora` are available. + +- LoRA training in SDXL now supports block-wise learning rates and block-wise dim (rank). PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331) + - Specify the learning rate and dim (rank) for each block. + - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only). + +- Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。 @@ -171,7 +182,18 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - `--fused_optimizer_groups` は `--fused_backward_pass` と併用できません。AdaFactor 使用時は Fused optimizer よりも若干メモリ使用量は大きくなります。PyTorch 2.1 以降が必要です。 - 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。やはり SDXL の学習でのみ効果があります。 -- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247] +- LoRA+ がサポートされました。PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) rockerBOO 氏に感謝します。 + - LoRA の UP 側(LoRA-B)の学習率を上げることで学習速度の向上を図る手法です。倍数で指定します。元の論文では 16 が推奨されていますが、データセット等にもよりますので、適宜調整してください。PR もあわせてご覧ください。 + - `--network_args` で `loraplus_lr_ratio` を指定します。例:`--network_args "loraplus_lr_ratio=16"` + - `loraplus_unet_lr_ratio` と `loraplus_lr_ratio` で、U-Net および Text Encoder に個別の値を指定することも可能です。 + - 例:`--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` または `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` など + - `network_module` の `networks.lora` および `networks.dylora` で使用可能です。 + +- SDXL の LoRA で階層別学習率、階層別 dim (rank) をサポートしました。PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331) + - ブロックごとに学習率および dim (rank) を指定することができます。 + - 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。 + +- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) ### Apr 7, 2024 / 2024-04-07: v0.8.7 diff --git a/docs/train_network_README-ja.md b/docs/train_network_README-ja.md index 2205a773..46085117 100644 --- a/docs/train_network_README-ja.md +++ b/docs/train_network_README-ja.md @@ -181,16 +181,16 @@ python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.saf 詳細は[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) をご覧ください。 -SDXLは現在サポートしていません。 - フルモデルの25個のブロックの重みを指定できます。最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。 +SDXL では down/up 9 個、middle 3 個の値を指定してください。 + `--network_args` で以下の引数を指定してください。 - `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。 - - ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。 + - ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個(SDXL では 9 個)の数値を指定します。 - プリセットからの指定 : `"down_lr_weight=sine"` のように指定します(サインカーブで重みを指定します)。sine, cosine, linear, reverse_linear, zeros が指定可能です。また `"down_lr_weight=cosine+.25"` のように `+数値` を追加すると、指定した数値を加算します(0.25~1.25になります)。 -- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します。 +- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します(SDXL の場合は 3 個)。 - `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。 - 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。 - `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。 @@ -215,6 +215,9 @@ network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_l フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。 +SDXL では 23 個の値を指定してください。一部のブロックにはLoRA が存在しませんが、`sdxl_train.py` の[階層別学習率](./train_SDXL-en.md) との互換性のためです。 +対応は、`0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out` です。 + `--network_args` で以下の引数を指定してください。 - `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。 diff --git a/networks/lora.py b/networks/lora.py index 6e564557..00d21b0e 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -985,7 +985,7 @@ class LoRANetwork(torch.nn.Module): skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: - logger.warn( + logger.warning( f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) for name in skipped: From 9ddb4d7a0138722913f6f1a6f1bf30f7ff89bb5b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 17:55:08 +0900 Subject: [PATCH 035/356] update readme and help message etc. --- README.md | 8 ++++++++ library/sdxl_model_util.py | 6 ++++-- library/sdxl_train_util.py | 6 +++++- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index b10da0f2..ed91d6d7 100644 --- a/README.md +++ b/README.md @@ -165,6 +165,10 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Specify the learning rate and dim (rank) for each block. - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only). +- An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra! + - It seems that the model file loading is faster in the WSL environment etc. + - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`. + - Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 @@ -193,6 +197,10 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - ブロックごとに学習率および dim (rank) を指定することができます。 - 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。 +- SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。 + - WSL 環境等でモデルファイルの読み込みが高速化されるようです。 + - `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。 + - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index e6fcb1f9..4fad78a1 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -9,8 +9,10 @@ from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionMode from library import model_util from library import sdxl_original_unet from .utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) VAE_SCALE_FACTOR = 0.13025 @@ -171,8 +173,8 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty # Load the state dict if model_util.is_safetensors(ckpt_path): checkpoint = None - if(disable_mmap): - state_dict = safetensors.torch.load(open(ckpt_path, 'rb').read()) + if disable_mmap: + state_dict = safetensors.torch.load(open(ckpt_path, "rb").read()) else: try: state_dict = load_file(ckpt_path, device=map_location) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 106c5b45..b74bea91 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -5,6 +5,7 @@ from typing import Optional import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from accelerate import init_empty_weights @@ -13,8 +14,10 @@ from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline from .utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) TOKENIZER1_PATH = "openai/clip-vit-large-patch14" @@ -44,7 +47,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): weight_dtype, accelerator.device if args.lowram else "cpu", model_dtype, - args.disable_mmap_load_safetensors + args.disable_mmap_load_safetensors, ) # work on low-ram device @@ -336,6 +339,7 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--disable_mmap_load_safetensors", action="store_true", + help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", ) From 3701507874c920e09e402980363702a91a67da3d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 20:56:56 +0900 Subject: [PATCH 036/356] raise original error if error is occured in checking latents --- library/train_util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index d157cdbc..8a69f0be 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2136,9 +2136,8 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): if npz["latents_flipped"].shape[1:3] != expected_latents_size: return False except Exception as e: - print(npz_path) - print(e) - return False + logger.error(f"Error loading file: {npz_path}") + raise e return True From 39b82f26e5f9df6518a4e32f4b91b4c46cc667fb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 20:58:45 +0900 Subject: [PATCH 037/356] update readme --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index ed91d6d7..24585341 100644 --- a/README.md +++ b/README.md @@ -169,6 +169,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - It seems that the model file loading is faster in the WSL environment etc. - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`. +- When there is an error in the cached latents file on disk, the file name is now displayed. PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Thanks to Cauldrath! + - Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 @@ -201,6 +203,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - WSL 環境等でモデルファイルの読み込みが高速化されるようです。 - `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。 +- ディスクにキャッシュされた latents ファイルに何らかのエラーがあったとき、そのファイル名が表示されるようになりました。 PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Cauldrath 氏に感謝します。 + - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) From 16677da0d90ad9094a0301990b831a8dd6c0e957 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 22:15:07 +0900 Subject: [PATCH 038/356] fix create_network_from_weights doesn't work --- networks/lora.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/networks/lora.py b/networks/lora.py index 00d21b0e..79dc6ec0 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -757,6 +757,9 @@ def get_block_index(lora_name: str, is_sdxl: bool = False) -> int: # Create network from weights for inference, weights are not loaded here (because can be merged) def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel) + if weights_sd is None: if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file, safe_open @@ -792,7 +795,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh ) # block lr - block_lr_weight = parse_block_lr_kwargs(kwargs) + block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs) if block_lr_weight is not None: network.set_block_lr_weight(block_lr_weight) From 589c2aa025d277497de32c2ceb8a9e76f4ca4bf2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 13 May 2024 21:20:37 +0900 Subject: [PATCH 039/356] update README --- README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.md b/README.md index 24585341..9d042a41 100644 --- a/README.md +++ b/README.md @@ -171,6 +171,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - When there is an error in the cached latents file on disk, the file name is now displayed. PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Thanks to Cauldrath! +- Fixed an error that occurs when specifying `--max_dataloader_n_workers` in `tag_images_by_wd14_tagger.py` when Onnx is not used. PR [#1291]( +https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290]( +https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! + +- Fixed a bug that `caption_separator` cannot be specified in the subset in the dataset settings .toml file. [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) and [#1313](https://github.com/kohya-ss/sd-scripts/pull/1312) Thanks to rockerBOO! + - Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 @@ -205,6 +211,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - ディスクにキャッシュされた latents ファイルに何らかのエラーがあったとき、そのファイル名が表示されるようになりました。 PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Cauldrath 氏に感謝します。 +- `tag_images_by_wd14_tagger.py` で Onnx 未使用時に `--max_dataloader_n_workers` を指定するとエラーになる不具合が修正されました。 PR [#1291]( +https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290]( +https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します。 + +- データセット設定の .toml ファイルで、`caption_separator` が subset に指定できない不具合が修正されました。 PR [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) および [#1313](https://github.com/kohya-ss/sd-scripts/pull/1312) rockerBOO 氏に感謝します。 + - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) From 153764a687d7553866335554d2b35ba89a123297 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 15 May 2024 20:21:49 +0900 Subject: [PATCH 040/356] add prompt option '--f' for filename --- README.md | 3 +++ gen_img.py | 55 +++++++++++++++++++++++++++++++++++++++--------------- 2 files changed, 43 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 9d042a41..52d80121 100644 --- a/README.md +++ b/README.md @@ -179,6 +179,8 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) +- Added a prompt option `--f` to `gen_imgs.py` to specify the file name when saving. + - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。 - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。 @@ -219,6 +221,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) +- `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。 ### Apr 7, 2024 / 2024-04-07: v0.8.7 diff --git a/gen_img.py b/gen_img.py index 4fe89871..d0a8f814 100644 --- a/gen_img.py +++ b/gen_img.py @@ -1435,6 +1435,7 @@ class BatchDataBase(NamedTuple): clip_prompt: str guide_image: Any raw_prompt: str + file_name: Optional[str] class BatchDataExt(NamedTuple): @@ -2316,7 +2317,7 @@ def main(args): # このバッチの情報を取り出す ( return_latents, - (step_first, _, _, _, init_image, mask_image, _, guide_image, _), + (step_first, _, _, _, init_image, mask_image, _, guide_image, _, _), ( width, height, @@ -2339,6 +2340,7 @@ def main(args): prompts = [] negative_prompts = [] raw_prompts = [] + filenames = [] start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) noises = [ torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) @@ -2371,7 +2373,7 @@ def main(args): all_guide_images_are_same = True for i, ( _, - (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt, filename), _, ) in enumerate(batch): prompts.append(prompt) @@ -2379,6 +2381,7 @@ def main(args): seeds.append(seed) clip_prompts.append(clip_prompt) raw_prompts.append(raw_prompt) + filenames.append(filename) if init_image is not None: init_images.append(init_image) @@ -2478,8 +2481,8 @@ def main(args): # save image highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) + for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt, filename) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts, filenames) ): if highres_fix: seed -= 1 # record original seed @@ -2505,17 +2508,23 @@ def main(args): metadata.add_text("crop-top", str(crop_top)) metadata.add_text("crop-left", str(crop_left)) - if args.use_original_file_name and init_images is not None: - if type(init_images) is list: - fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" - else: - fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" - elif args.sequential_file_name: - fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + if filename is not None: + fln = filename else: - fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + else: + fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" - image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + if fln.endswith(".webp"): + image.save(os.path.join(args.outdir, fln), pnginfo=metadata, quality=100) # lossy + else: + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) if not args.no_preview and not highres_1st and args.interactive: try: @@ -2562,6 +2571,7 @@ def main(args): # repeat prompt for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] + filename = None if pi == 0 or len(raw_prompts) > 1: # parse prompt: if prompt is not changed, skip parsing @@ -2783,6 +2793,12 @@ def main(args): logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") continue + m = re.match(r"f (.+)", parg, re.IGNORECASE) + if m: # filename + filename = m.group(1) + logger.info(f"filename: {filename}") + continue + except ValueError as ex: logger.error(f"Exception in parsing / 解析エラー: {parg}") logger.error(f"{ex}") @@ -2873,7 +2889,16 @@ def main(args): b1 = BatchData( False, BatchDataBase( - global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + global_step, + prompt, + negative_prompt, + seed, + init_image, + mask_image, + clip_prompt, + guide_image, + raw_prompt, + filename, ), BatchDataExt( width, @@ -2916,7 +2941,7 @@ def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() add_logging_arguments(parser) - + parser.add_argument( "--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む" ) From 146edce6934beee050d8e73458dad794449a0ff4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 18 May 2024 11:05:04 +0900 Subject: [PATCH 041/356] support Diffusers' based SDXL LoRA key for inference --- networks/lora.py | 49 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/networks/lora.py b/networks/lora.py index 79dc6ec0..9f159f5d 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -755,6 +755,52 @@ def get_block_index(lora_name: str, is_sdxl: bool = False) -> int: return block_idx +def convert_diffusers_to_sai_if_needed(weights_sd): + # only supports U-Net LoRA modules + + found_up_down_blocks = False + for k in list(weights_sd.keys()): + if "down_blocks" in k: + found_up_down_blocks = True + break + if "up_blocks" in k: + found_up_down_blocks = True + break + if not found_up_down_blocks: + return + + from library.sdxl_model_util import make_unet_conversion_map + + unet_conversion_map = make_unet_conversion_map() + unet_conversion_map = {hf.replace(".", "_")[:-1]: sd.replace(".", "_")[:-1] for sd, hf in unet_conversion_map} + + # # add extra conversion + # unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv" + + logger.info(f"Converting LoRA keys from Diffusers to SAI") + lora_unet_prefix = "lora_unet_" + for k in list(weights_sd.keys()): + if not k.startswith(lora_unet_prefix): + continue + + unet_module_name = k[len(lora_unet_prefix) :].split(".")[0] + + # search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small + for hf_module_name, sd_module_name in unet_conversion_map.items(): + if hf_module_name in unet_module_name: + new_key = ( + lora_unet_prefix + + unet_module_name.replace(hf_module_name, sd_module_name) + + k[len(lora_unet_prefix) + len(unet_module_name) :] + ) + weights_sd[new_key] = weights_sd.pop(k) + found = True + break + + if not found: + logger.warning(f"Key {k} is not found in unet_conversion_map") + + # Create network from weights for inference, weights are not loaded here (because can be merged) def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True @@ -768,6 +814,9 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh else: weights_sd = torch.load(file, map_location="cpu") + # if keys are Diffusers based, convert to SAI based + convert_diffusers_to_sai_if_needed(weights_sd) + # get dim/alpha mapping modules_dim = {} modules_alpha = {} From 2f19175dfeb98e5ad93a633c79fa846d67210844 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 May 2024 15:38:37 +0900 Subject: [PATCH 042/356] update README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 52d80121..b9852e0a 100644 --- a/README.md +++ b/README.md @@ -179,7 +179,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) -- Added a prompt option `--f` to `gen_imgs.py` to specify the file name when saving. +- Added a prompt option `--f` to `gen_imgs.py` to specify the file name when saving. Also, Diffusers-based keys for LoRA weights are now supported. - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。 @@ -221,7 +221,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) -- `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。 +- `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。また同スクリプトで Diffusers ベースのキーを持つ LoRA の重みに対応しました。 ### Apr 7, 2024 / 2024-04-07: v0.8.7 From e3ddd1fbbe4e00f49649f5aabd470b9dccf3019d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 May 2024 16:26:10 +0900 Subject: [PATCH 043/356] update README and format code --- README.md | 4 ++++ sdxl_train_control_net_lllite.py | 9 +++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b9852e0a..5d035eb6 100644 --- a/README.md +++ b/README.md @@ -177,6 +177,8 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - Fixed a bug that `caption_separator` cannot be specified in the subset in the dataset settings .toml file. [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) and [#1313](https://github.com/kohya-ss/sd-scripts/pull/1312) Thanks to rockerBOO! +- Fixed a potential bug in ControlNet-LLLite training. PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) Thanks to aria1th! + - Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) - Added a prompt option `--f` to `gen_imgs.py` to specify the file name when saving. Also, Diffusers-based keys for LoRA weights are now supported. @@ -219,6 +221,8 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します - データセット設定の .toml ファイルで、`caption_separator` が subset に指定できない不具合が修正されました。 PR [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) および [#1313](https://github.com/kohya-ss/sd-scripts/pull/1312) rockerBOO 氏に感謝します。 +- ControlNet-LLLite 学習時の潜在バグが修正されました。 PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) aria1th 氏に感謝します。 + - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) - `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。また同スクリプトで Diffusers ベースのキーを持つ LoRA の重みに対応しました。 diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 6ad6e763..09b6d73b 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -15,6 +15,7 @@ from tqdm import tqdm import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -439,7 +440,9 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -458,7 +461,9 @@ def train(args): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight From c68baae48033fe9794860518fe052dbf8def905e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 May 2024 17:21:04 +0900 Subject: [PATCH 044/356] add `--log_config` option to enable/disable output training config --- README.md | 6 ++++++ fine_tune.py | 20 +++++++++++++++----- library/train_util.py | 16 +++++++++++++--- sdxl_train.py | 2 +- sdxl_train_control_net_lllite.py | 2 +- sdxl_train_control_net_lllite_old.py | 2 +- train_controlnet.py | 2 +- train_db.py | 2 +- train_network.py | 2 +- train_textual_inversion.py | 2 +- train_textual_inversion_XTI.py | 2 +- 11 files changed, 42 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 5d035eb6..cd774459 100644 --- a/README.md +++ b/README.md @@ -165,6 +165,9 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Specify the learning rate and dim (rank) for each block. - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only). +- Training scripts can now output training settings to wandb or Tensor Board logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa! + - Some settings, such as API keys and directory specifications, are not output due to security issues. + - An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra! - It seems that the model file loading is faster in the WSL environment etc. - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`. @@ -209,6 +212,9 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - ブロックごとに学習率および dim (rank) を指定することができます。 - 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。 +- 各学習スクリプトで学習設定を wandb や Tensor Board などのログに出力できるようになりました。`--log_config` オプションを指定してください。PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) ccharest93 氏、plucked 氏、rockerBOO 氏および VelocityRa 氏に感謝します。 + - API キーや各種ディレクトリ指定など、一部の設定はセキュリティ上の問題があるため出力されません。 + - SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。 - WSL 環境等でモデルファイルの読み込みが高速化されるようです。 - `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。 diff --git a/fine_tune.py b/fine_tune.py index 77a1a4f3..d865cd2d 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -310,7 +310,11 @@ def train(args): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) # For --sample_at_first train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) @@ -354,7 +358,9 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) # Predict the noise residual with accelerator.autocast(): @@ -368,7 +374,9 @@ def train(args): if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: @@ -380,7 +388,9 @@ def train(args): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + ) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: @@ -471,7 +481,7 @@ def train(args): accelerator.end_training() - if is_main_process and (args.save_state or args.save_state_on_train_end): + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す diff --git a/library/train_util.py b/library/train_util.py index 84764263..41047147 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3180,6 +3180,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)", ) + parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する") parser.add_argument( "--noise_offset", @@ -3388,7 +3389,15 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser): help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要", ) -def filter_sensitive_args(args: argparse.Namespace): + +def get_sanitized_config_or_none(args: argparse.Namespace): + # if `--log_config` is enabled, return args for logging. if not, return None. + # when `--log_config is enabled, filter out sensitive values from args + # if wandb is not enabled, the log is not exposed to the public, but it is fine to filter out sensitive values to be safe + + if not args.log_config: + return None + sensitive_args = ["wandb_api_key", "huggingface_token"] sensitive_path_args = [ "pretrained_model_name_or_path", @@ -3402,9 +3411,9 @@ def filter_sensitive_args(args: argparse.Namespace): ] filtered_args = {} for k, v in vars(args).items(): - # filter out sensitive values + # filter out sensitive values and convert to string if necessary if k not in sensitive_args + sensitive_path_args: - #Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`. + # Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`. if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int): filtered_args[k] = v # accelerate does not support lists @@ -3416,6 +3425,7 @@ def filter_sensitive_args(args: argparse.Namespace): return filtered_args + # verify command line args for training def verify_command_line_training_args(args: argparse.Namespace): # if wandb is enabled, the command line is exposed to the public diff --git a/sdxl_train.py b/sdxl_train.py index 4c4e3872..11f9892a 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -589,7 +589,7 @@ def train(args): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs) + accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs) # For --sample_at_first sdxl_train_util.sample_images( diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index b141965f..30131090 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -354,7 +354,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs ) loss_recorder = train_util.LossRecorder() diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 9490cf6f..292a0463 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -324,7 +324,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs ) loss_recorder = train_util.LossRecorder() diff --git a/train_controlnet.py b/train_controlnet.py index 793f79c7..9994dd99 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -344,7 +344,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs + "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs ) loss_recorder = train_util.LossRecorder() diff --git a/train_db.py b/train_db.py index 4f901829..a5408cd3 100644 --- a/train_db.py +++ b/train_db.py @@ -290,7 +290,7 @@ def train(args): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs) + accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs) # For --sample_at_first train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) diff --git a/train_network.py b/train_network.py index 401a1c70..38e4888e 100644 --- a/train_network.py +++ b/train_network.py @@ -774,7 +774,7 @@ class NetworkTrainer: if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs + "network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs ) loss_recorder = train_util.LossRecorder() diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 56a38739..184607d1 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -510,7 +510,7 @@ class TextualInversionTrainer: if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs ) # function for saving/removing diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 69178523..8eed00fa 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -407,7 +407,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs ) # function for saving/removing From e4d9e3c843f5d9bfbfe56bd44c8f6a04d370201e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 May 2024 17:46:07 +0900 Subject: [PATCH 045/356] remove dependency for omegaconf #ref 1284 --- README.md | 4 ++++ requirements.txt | 1 - train_controlnet.py | 38 +++++++++++++++++++++++++++++++------- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index cd774459..04769a4c 100644 --- a/README.md +++ b/README.md @@ -167,6 +167,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Training scripts can now output training settings to wandb or Tensor Board logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa! - Some settings, such as API keys and directory specifications, are not output due to security issues. + +- The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds! - An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra! - It seems that the model file loading is faster in the WSL environment etc. @@ -215,6 +217,8 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - 各学習スクリプトで学習設定を wandb や Tensor Board などのログに出力できるようになりました。`--log_config` オプションを指定してください。PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) ccharest93 氏、plucked 氏、rockerBOO 氏および VelocityRa 氏に感謝します。 - API キーや各種ディレクトリ指定など、一部の設定はセキュリティ上の問題があるため出力されません。 +- SD1.5/2.x 用の ControlNet 学習スクリプト `train_controlnet.py` が動作しなくなっていたのが修正されました。PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) sdbds 氏に感謝します。 + - SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。 - WSL 環境等でモデルファイルの読み込みが高速化されるようです。 - `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。 diff --git a/requirements.txt b/requirements.txt index 9495dab2..e99775b8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,7 +17,6 @@ easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 huggingface-hub==0.20.1 -omegaconf==2.3.0 # for Image utils imagesize==1.4.1 # for BLIP captioning diff --git a/train_controlnet.py b/train_controlnet.py index 3a1fa9de..c9ac6c5a 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -5,7 +5,8 @@ import os import random import time from multiprocessing import Value -from omegaconf import OmegaConf + +# from omegaconf import OmegaConf import toml from tqdm import tqdm @@ -13,6 +14,7 @@ from tqdm import tqdm import torch from library import deepspeed_utils from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -197,7 +199,23 @@ def train(args): "resnet_time_scale_shift": "default", "projection_class_embeddings_input_dim": None, } - unet.config = OmegaConf.create(unet.config) + # unet.config = OmegaConf.create(unet.config) + + # make unet.config iterable and accessible by attribute + class CustomConfig: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + def __getattr__(self, name): + if name in self.__dict__: + return self.__dict__[name] + else: + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + def __contains__(self, name): + return name in self.__dict__ + + unet.config = CustomConfig(**unet.config) controlnet = ControlNetModel.from_unet(unet) @@ -230,7 +248,7 @@ def train(args): ) vae.to("cpu") clean_memory_on_device(accelerator.device) - + accelerator.wait_for_everyone() if args.gradient_checkpointing: @@ -239,7 +257,7 @@ def train(args): # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - trainable_params = controlnet.parameters() + trainable_params = list(controlnet.parameters()) _, _, optimizer = train_util.get_optimizer(args, trainable_params) @@ -348,7 +366,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -424,7 +444,9 @@ def train(args): ) # Sample a random timestep for each image - timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device) + timesteps, huber_c = train_util.get_timesteps_and_huber_c( + args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device + ) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -456,7 +478,9 @@ def train(args): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight From 4c798129b04955caad1c48405de168ff63a3809c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 May 2024 19:00:32 +0900 Subject: [PATCH 046/356] update README --- README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 04769a4c..d0f2d65b 100644 --- a/README.md +++ b/README.md @@ -165,11 +165,14 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Specify the learning rate and dim (rank) for each block. - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only). +- Negative learning rates can now be specified during SDXL model training. PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Thanks to Cauldrath! + - The model is trained to move away from the training images, so the model is easily collapsed. Use with caution. A value close to 0 is recommended. + - Training scripts can now output training settings to wandb or Tensor Board logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa! - Some settings, such as API keys and directory specifications, are not output due to security issues. - The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds! - + - An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra! - It seems that the model file loading is faster in the WSL environment etc. - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`. @@ -214,6 +217,9 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - ブロックごとに学習率および dim (rank) を指定することができます。 - 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。 +- `sdxl_train.py` での SDXL モデル学習時に負の学習率が指定できるようになりました。PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Cauldrath 氏に感謝します。 + - 学習画像から離れるように学習するため、モデルは容易に崩壊します。注意して使用してください。0 に近い値を推奨します。 + - 各学習スクリプトで学習設定を wandb や Tensor Board などのログに出力できるようになりました。`--log_config` オプションを指定してください。PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) ccharest93 氏、plucked 氏、rockerBOO 氏および VelocityRa 氏に感謝します。 - API キーや各種ディレクトリ指定など、一部の設定はセキュリティ上の問題があるため出力されません。 From febc5c59fad74dfcead9064033171a9c674e4870 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 May 2024 19:03:43 +0900 Subject: [PATCH 047/356] update README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index d0f2d65b..838e4022 100644 --- a/README.md +++ b/README.md @@ -167,6 +167,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Negative learning rates can now be specified during SDXL model training. PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Thanks to Cauldrath! - The model is trained to move away from the training images, so the model is easily collapsed. Use with caution. A value close to 0 is recommended. + - When specifying from the command line, use `=` like `--learning_rate=-1e-7`. - Training scripts can now output training settings to wandb or Tensor Board logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa! - Some settings, such as API keys and directory specifications, are not output due to security issues. @@ -219,6 +220,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - `sdxl_train.py` での SDXL モデル学習時に負の学習率が指定できるようになりました。PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Cauldrath 氏に感謝します。 - 学習画像から離れるように学習するため、モデルは容易に崩壊します。注意して使用してください。0 に近い値を推奨します。 + - コマンドラインから指定する場合、`--learning_rate=-1e-7` のように`=` を使ってください。 - 各学習スクリプトで学習設定を wandb や Tensor Board などのログに出力できるようになりました。`--log_config` オプションを指定してください。PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) ccharest93 氏、plucked 氏、rockerBOO 氏および VelocityRa 氏に感謝します。 - API キーや各種ディレクトリ指定など、一部の設定はセキュリティ上の問題があるため出力されません。 From db6752901fc204686e460255797b188cb28611a5 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun, 19 May 2024 19:07:25 +0900 Subject: [PATCH 048/356] =?UTF-8?q?=E7=94=BB=E5=83=8F=E3=81=AE=E3=82=A2?= =?UTF-8?q?=E3=83=AB=E3=83=95=E3=82=A1=E3=83=81=E3=83=A3=E3=83=B3=E3=83=8D?= =?UTF-8?q?=E3=83=AB=E3=82=92loss=E3=81=AE=E3=83=9E=E3=82=B9=E3=82=AF?= =?UTF-8?q?=E3=81=A8=E3=81=97=E3=81=A6=E4=BD=BF=E7=94=A8=E3=81=99=E3=82=8B?= =?UTF-8?q?=E3=82=AA=E3=83=97=E3=82=B7=E3=83=A7=E3=83=B3=E3=82=92=E8=BF=BD?= =?UTF-8?q?=E5=8A=A0=20(#1223)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add alpha_mask parameter and apply masked loss * Fix type hint in trim_and_resize_if_required function * Refactor code to use keyword arguments in train_util.py * Fix alpha mask flipping logic * Fix alpha mask initialization * Fix alpha_mask transformation * Cache alpha_mask * Update alpha_masks to be on CPU * Set flipped_alpha_masks to Null if option disabled * Check if alpha_mask is None * Set alpha_mask to None if option disabled * Add description of alpha_mask option to docs --- docs/train_network_README-ja.md | 2 + docs/train_network_README-zh.md | 2 + library/config_util.py | 2 + library/custom_train_functions.py | 5 +- library/train_util.py | 203 ++++++++++++------------------ sdxl_train.py | 4 +- train_db.py | 4 +- train_network.py | 4 +- train_textual_inversion.py | 4 +- train_textual_inversion_XTI.py | 4 +- 10 files changed, 105 insertions(+), 129 deletions(-) diff --git a/docs/train_network_README-ja.md b/docs/train_network_README-ja.md index 46085117..55c80c4b 100644 --- a/docs/train_network_README-ja.md +++ b/docs/train_network_README-ja.md @@ -102,6 +102,8 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py * Text Encoderに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。Text Encoderのほうを若干低めの学習率(5e-5など)にしたほうが良い、という話もあるようです。 * `--network_args` * 複数の引数を指定できます。後述します。 +* `--alpha_mask` + * 画像のアルファ値をマスクとして使用します。透過画像を学習する際に使用します。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223) `--network_train_unet_only` と `--network_train_text_encoder_only` の両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。 diff --git a/docs/train_network_README-zh.md b/docs/train_network_README-zh.md index ed7a0c4e..830014f7 100644 --- a/docs/train_network_README-zh.md +++ b/docs/train_network_README-zh.md @@ -101,6 +101,8 @@ LoRA的模型将会被保存在通过`--output_dir`选项指定的文件夹中 * 当在Text Encoder相关的LoRA模块中使用与常规学习率(由`--learning_rate`选项指定)不同的学习率时,应指定此选项。可能最好将Text Encoder的学习率稍微降低(例如5e-5)。 * `--network_args` * 可以指定多个参数。将在下面详细说明。 +* `--alpha_mask` + * 使用图像的 Alpha 值作为遮罩。这在学习透明图像时使用。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223) 当未指定`--network_train_unet_only`和`--network_train_text_encoder_only`时(默认情况),将启用Text Encoder和U-Net的两个LoRA模块。 diff --git a/library/config_util.py b/library/config_util.py index 59f5f86d..82baab83 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -78,6 +78,7 @@ class BaseSubsetParams: caption_tag_dropout_rate: float = 0.0 token_warmup_min: int = 1 token_warmup_step: float = 0 + alpha_mask: bool = False @dataclass @@ -538,6 +539,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu random_crop: {subset.random_crop} token_warmup_min: {subset.token_warmup_min}, token_warmup_step: {subset.token_warmup_step}, + alpha_mask: {subset.alpha_mask}, """ ), " ", diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 406e0e36..fad12740 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -479,9 +479,10 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): return noise -def apply_masked_loss(loss, batch): +def apply_masked_loss(loss, mask_image): # mask image is -1 to 1. we need to convert it to 0 to 1 - mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel + # mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel + mask_image = mask_image.to(dtype=loss.dtype) # resize to the same size as the loss mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area") diff --git a/library/train_util.py b/library/train_util.py index 41047147..20f8055d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -159,6 +159,9 @@ class ImageInfo: self.text_encoder_outputs1: Optional[torch.Tensor] = None self.text_encoder_outputs2: Optional[torch.Tensor] = None self.text_encoder_pool2: Optional[torch.Tensor] = None + self.alpha_mask: Optional[torch.Tensor] = None + self.alpha_mask_flipped: Optional[torch.Tensor] = None + self.use_alpha_mask: bool = False class BucketManager: @@ -379,6 +382,7 @@ class BaseSubset: caption_suffix: Optional[str], token_warmup_min: int, token_warmup_step: Union[float, int], + alpha_mask: bool, ) -> None: self.image_dir = image_dir self.num_repeats = num_repeats @@ -403,6 +407,7 @@ class BaseSubset: self.img_count = 0 + self.alpha_mask = alpha_mask class DreamBoothSubset(BaseSubset): def __init__( @@ -412,47 +417,13 @@ class DreamBoothSubset(BaseSubset): class_tokens: Optional[str], caption_extension: str, cache_info: bool, - num_repeats, - shuffle_caption, - caption_separator: str, - keep_tokens, - keep_tokens_separator, - secondary_separator, - enable_wildcard, - color_aug, - flip_aug, - face_crop_aug_range, - random_crop, - caption_dropout_rate, - caption_dropout_every_n_epochs, - caption_tag_dropout_rate, - caption_prefix, - caption_suffix, - token_warmup_min, - token_warmup_step, + **kwargs, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" super().__init__( image_dir, - num_repeats, - shuffle_caption, - caption_separator, - keep_tokens, - keep_tokens_separator, - secondary_separator, - enable_wildcard, - color_aug, - flip_aug, - face_crop_aug_range, - random_crop, - caption_dropout_rate, - caption_dropout_every_n_epochs, - caption_tag_dropout_rate, - caption_prefix, - caption_suffix, - token_warmup_min, - token_warmup_step, + **kwargs, ) self.is_reg = is_reg @@ -473,47 +444,13 @@ class FineTuningSubset(BaseSubset): self, image_dir, metadata_file: str, - num_repeats, - shuffle_caption, - caption_separator, - keep_tokens, - keep_tokens_separator, - secondary_separator, - enable_wildcard, - color_aug, - flip_aug, - face_crop_aug_range, - random_crop, - caption_dropout_rate, - caption_dropout_every_n_epochs, - caption_tag_dropout_rate, - caption_prefix, - caption_suffix, - token_warmup_min, - token_warmup_step, + **kwargs, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" super().__init__( image_dir, - num_repeats, - shuffle_caption, - caption_separator, - keep_tokens, - keep_tokens_separator, - secondary_separator, - enable_wildcard, - color_aug, - flip_aug, - face_crop_aug_range, - random_crop, - caption_dropout_rate, - caption_dropout_every_n_epochs, - caption_tag_dropout_rate, - caption_prefix, - caption_suffix, - token_warmup_min, - token_warmup_step, + **kwargs, ) self.metadata_file = metadata_file @@ -531,47 +468,13 @@ class ControlNetSubset(BaseSubset): conditioning_data_dir: str, caption_extension: str, cache_info: bool, - num_repeats, - shuffle_caption, - caption_separator, - keep_tokens, - keep_tokens_separator, - secondary_separator, - enable_wildcard, - color_aug, - flip_aug, - face_crop_aug_range, - random_crop, - caption_dropout_rate, - caption_dropout_every_n_epochs, - caption_tag_dropout_rate, - caption_prefix, - caption_suffix, - token_warmup_min, - token_warmup_step, + **kwargs, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" super().__init__( image_dir, - num_repeats, - shuffle_caption, - caption_separator, - keep_tokens, - keep_tokens_separator, - secondary_separator, - enable_wildcard, - color_aug, - flip_aug, - face_crop_aug_range, - random_crop, - caption_dropout_rate, - caption_dropout_every_n_epochs, - caption_tag_dropout_rate, - caption_prefix, - caption_suffix, - token_warmup_min, - token_warmup_step, + **kwargs, ) self.conditioning_data_dir = conditioning_data_dir @@ -985,6 +888,8 @@ class BaseDataset(torch.utils.data.Dataset): for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] + info.use_alpha_mask = subset.alpha_mask + if info.latents_npz is not None: # fine tuning dataset continue @@ -1088,8 +993,8 @@ class BaseDataset(torch.utils.data.Dataset): def get_image_size(self, image_path): return imagesize.get(image_path) - def load_image_with_face_info(self, subset: BaseSubset, image_path: str): - img = load_image(image_path) + def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False): + img = load_image(image_path, alpha_mask) face_cx = face_cy = face_w = face_h = 0 if subset.face_crop_aug_range is not None: @@ -1166,6 +1071,7 @@ class BaseDataset(torch.utils.data.Dataset): input_ids_list = [] input_ids2_list = [] latents_list = [] + alpha_mask_list = [] images = [] original_sizes_hw = [] crop_top_lefts = [] @@ -1190,21 +1096,27 @@ class BaseDataset(torch.utils.data.Dataset): crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped if not flipped: latents = image_info.latents + alpha_mask = image_info.alpha_mask else: latents = image_info.latents_flipped - + alpha_mask = image_info.alpha_mask_flipped + image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz) + latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(image_info.latents_npz) if flipped: latents = flipped_latents + alpha_mask = flipped_alpha_mask del flipped_latents + del flipped_alpha_mask latents = torch.FloatTensor(latents) + if alpha_mask is not None: + alpha_mask = torch.FloatTensor(alpha_mask) image = None else: # 画像を読み込み、必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path) + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path, subset.alpha_mask) im_h, im_w = img.shape[0:2] if self.enable_bucket: @@ -1241,11 +1153,22 @@ class BaseDataset(torch.utils.data.Dataset): if flipped: img = img[:, ::-1, :].copy() # copy to avoid negative stride problem + if subset.alpha_mask: + if img.shape[2] == 4: + alpha_mask = img[:, :, 3] # [W,H] + else: + alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H] + alpha_mask = transforms.ToTensor()(alpha_mask) + else: + alpha_mask = None + img = img[:, :, :3] # remove alpha channel + latents = None image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる images.append(image) latents_list.append(latents) + alpha_mask_list.append(alpha_mask) target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8) @@ -1348,6 +1271,8 @@ class BaseDataset(torch.utils.data.Dataset): example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions)) + example["alpha_mask"] = torch.stack(alpha_mask_list) if alpha_mask_list[0] is not None else None + if self.debug_dataset: example["image_keys"] = bucket[image_index : image_index + self.batch_size] return example @@ -2145,7 +2070,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) def load_latents_from_disk( npz_path, -) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]: +) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: npz = np.load(npz_path) if "latents" not in npz: raise ValueError(f"error: npz is old format. please re-generate {npz_path}") @@ -2154,13 +2079,19 @@ def load_latents_from_disk( original_size = npz["original_size"].tolist() crop_ltrb = npz["crop_ltrb"].tolist() flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None - return latents, original_size, crop_ltrb, flipped_latents + alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None + flipped_alpha_mask = npz["flipped_alpha_mask"] if "flipped_alpha_mask" in npz else None + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask -def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None): +def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None): kwargs = {} if flipped_latents_tensor is not None: kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() + if alpha_mask is not None: + kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() + if flipped_alpha_mask is not None: + kwargs["flipped_alpha_mask"] = flipped_alpha_mask.float().cpu().numpy() np.savez( npz_path, latents=latents_tensor.float().cpu().numpy(), @@ -2349,17 +2280,20 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: return train_dataset_group -def load_image(image_path): +def load_image(image_path, alpha=False): image = Image.open(image_path) if not image.mode == "RGB": - image = image.convert("RGB") + if alpha: + image = image.convert("RGBA") + else: + image = image.convert("RGB") img = np.array(image, np.uint8) return img # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom) def trim_and_resize_if_required( - random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int] + random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int] ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]: image_height, image_width = image.shape[0:2] original_size = (image_width, image_height) # size before resize @@ -2403,10 +2337,18 @@ def cache_batch_latents( latents_original_size and latents_crop_ltrb are also set """ images = [] + alpha_masks = [] for info in image_infos: - image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8) + image = load_image(info.absolute_path, info.use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + if info.use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [W,H] + image = image[:, :, :3] + else: + alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H] + alpha_masks.append(transforms.ToTensor()(alpha_mask)) image = IMAGE_TRANSFORMS(image) images.append(image) @@ -2419,25 +2361,37 @@ def cache_batch_latents( with torch.no_grad(): latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + if info.use_alpha_mask: + alpha_masks = torch.stack(alpha_masks, dim=0).to("cpu") + else: + alpha_masks = [None] * len(image_infos) + flipped_alpha_masks = [None] * len(image_infos) + if flip_aug: img_tensors = torch.flip(img_tensors, dims=[3]) with torch.no_grad(): flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + if info.use_alpha_mask: + flipped_alpha_masks = torch.flip(alpha_masks, dims=[3]) else: flipped_latents = [None] * len(latents) + flipped_alpha_masks = [None] * len(image_infos) - for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents): + for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks): # check NaN if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()): raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") if cache_to_disk: - save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent) + save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent, alpha_mask, flipped_alpha_mask) else: info.latents = latent if flip_aug: info.latents_flipped = flipped_latent + info.alpha_mask = alpha_mask + info.alpha_mask_flipped = flipped_alpha_mask + if not HIGH_VRAM: clean_memory_on_device(vae.device) @@ -3683,6 +3637,11 @@ def add_dataset_arguments( default=0, help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)", ) + parser.add_argument( + "--alpha_mask", + action="store_true", + help="use alpha channel as mask for training / 画像のアルファチャンネルをlossのマスクに使用する", + ) parser.add_argument( "--dataset_class", diff --git a/sdxl_train.py b/sdxl_train.py index 7c71a513..dcd06766 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -712,7 +712,9 @@ def train(args): noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) if args.masked_loss: - loss = apply_masked_loss(loss, batch) + loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) + if "alpha_mask" in batch and batch["alpha_mask"] is not None: + loss = apply_masked_loss(loss, batch["alpha_mask"]) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: diff --git a/train_db.py b/train_db.py index a5408cd3..c4690000 100644 --- a/train_db.py +++ b/train_db.py @@ -360,7 +360,9 @@ def train(args): loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) if args.masked_loss: - loss = apply_masked_loss(loss, batch) + loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) + if "alpha_mask" in batch and batch["alpha_mask"] is not None: + loss = apply_masked_loss(loss, batch["alpha_mask"]) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_network.py b/train_network.py index 38e4888e..cd1677ad 100644 --- a/train_network.py +++ b/train_network.py @@ -903,7 +903,9 @@ class NetworkTrainer: noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) if args.masked_loss: - loss = apply_masked_loss(loss, batch) + loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) + if "alpha_mask" in batch and batch["alpha_mask"] is not None: + loss = apply_masked_loss(loss, batch["alpha_mask"]) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 184607d1..a9c2a109 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -590,7 +590,9 @@ class TextualInversionTrainer: loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) if args.masked_loss: - loss = apply_masked_loss(loss, batch) + loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) + if "alpha_mask" in batch and batch["alpha_mask"] is not None: + loss = apply_masked_loss(loss, batch["alpha_mask"]) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 8eed00fa..959839cb 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -475,7 +475,9 @@ def train(args): loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) if args.masked_loss: - loss = apply_masked_loss(loss, batch) + loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) + if "alpha_mask" in batch and batch["alpha_mask"] is not None: + loss = apply_masked_loss(loss, batch["alpha_mask"]) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight From f2dd43e198f4bc059f4790ada041fa8f2a305f25 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 May 2024 19:23:59 +0900 Subject: [PATCH 049/356] revert kwargs to explicit declaration --- library/train_util.py | 158 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 142 insertions(+), 16 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 20f8055d..6cf28590 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -409,6 +409,7 @@ class BaseSubset: self.alpha_mask = alpha_mask + class DreamBoothSubset(BaseSubset): def __init__( self, @@ -417,13 +418,47 @@ class DreamBoothSubset(BaseSubset): class_tokens: Optional[str], caption_extension: str, cache_info: bool, - **kwargs, + num_repeats, + shuffle_caption, + caption_separator: str, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" super().__init__( image_dir, - **kwargs, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, ) self.is_reg = is_reg @@ -444,13 +479,47 @@ class FineTuningSubset(BaseSubset): self, image_dir, metadata_file: str, - **kwargs, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" super().__init__( image_dir, - **kwargs, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, ) self.metadata_file = metadata_file @@ -468,13 +537,47 @@ class ControlNetSubset(BaseSubset): conditioning_data_dir: str, caption_extension: str, cache_info: bool, - **kwargs, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" super().__init__( image_dir, - **kwargs, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, ) self.conditioning_data_dir = conditioning_data_dir @@ -1100,10 +1203,12 @@ class BaseDataset(torch.utils.data.Dataset): else: latents = image_info.latents_flipped alpha_mask = image_info.alpha_mask_flipped - + image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(image_info.latents_npz) + latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk( + image_info.latents_npz + ) if flipped: latents = flipped_latents alpha_mask = flipped_alpha_mask @@ -1116,7 +1221,9 @@ class BaseDataset(torch.utils.data.Dataset): image = None else: # 画像を読み込み、必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path, subset.alpha_mask) + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info( + subset, image_info.absolute_path, subset.alpha_mask + ) im_h, im_w = img.shape[0:2] if self.enable_bucket: @@ -1157,7 +1264,7 @@ class BaseDataset(torch.utils.data.Dataset): if img.shape[2] == 4: alpha_mask = img[:, :, 3] # [W,H] else: - alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H] + alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H] alpha_mask = transforms.ToTensor()(alpha_mask) else: alpha_mask = None @@ -2070,7 +2177,14 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) def load_latents_from_disk( npz_path, -) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: +) -> Tuple[ + Optional[torch.Tensor], + Optional[List[int]], + Optional[List[int]], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], +]: npz = np.load(npz_path) if "latents" not in npz: raise ValueError(f"error: npz is old format. please re-generate {npz_path}") @@ -2084,7 +2198,9 @@ def load_latents_from_disk( return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask -def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None): +def save_latents_to_disk( + npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None +): kwargs = {} if flipped_latents_tensor is not None: kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() @@ -2344,10 +2460,10 @@ def cache_batch_latents( image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) if info.use_alpha_mask: if image.shape[2] == 4: - alpha_mask = image[:, :, 3] # [W,H] + alpha_mask = image[:, :, 3] # [W,H] image = image[:, :, :3] else: - alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H] + alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H] alpha_masks.append(transforms.ToTensor()(alpha_mask)) image = IMAGE_TRANSFORMS(image) images.append(image) @@ -2377,13 +2493,23 @@ def cache_batch_latents( flipped_latents = [None] * len(latents) flipped_alpha_masks = [None] * len(image_infos) - for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks): + for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip( + image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks + ): # check NaN if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()): raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") if cache_to_disk: - save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent, alpha_mask, flipped_alpha_mask) + save_latents_to_disk( + info.latents_npz, + latent, + info.latents_original_size, + info.latents_crop_ltrb, + flipped_latent, + alpha_mask, + flipped_alpha_mask, + ) else: info.latents = latent if flip_aug: From da6fea3d9779970a1c573bf26fe37c924efc68d8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 May 2024 21:26:18 +0900 Subject: [PATCH 050/356] simplify and update alpha mask to work with various cases --- finetune/prepare_buckets_latents.py | 33 ++++-- library/config_util.py | 2 + library/custom_train_functions.py | 15 ++- library/train_util.py | 149 +++++++++++++++------------- sdxl_train.py | 6 +- tools/cache_latents.py | 12 ++- train_db.py | 6 +- train_network.py | 10 +- train_textual_inversion.py | 6 +- train_textual_inversion_XTI.py | 6 +- 10 files changed, 140 insertions(+), 105 deletions(-) diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 0389da38..019c737a 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -11,6 +11,7 @@ import cv2 import torch from library.device_utils import init_ipex, get_preferred_device + init_ipex() from torchvision import transforms @@ -18,8 +19,10 @@ from torchvision import transforms import library.model_util as model_util import library.train_util as train_util from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) DEVICE = get_preferred_device() @@ -89,7 +92,9 @@ def main(args): # bucketのサイズを計算する max_reso = tuple([int(t) for t in args.max_resolution.split(",")]) - assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" + assert ( + len(max_reso) == 2 + ), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" bucket_manager = train_util.BucketManager( args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps @@ -107,7 +112,7 @@ def main(args): def process_batch(is_last): for bucket in bucket_manager.buckets: if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: - train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False) + train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, args.alpha_mask, False) bucket.clear() # 読み込みの高速化のためにDataLoaderを使うオプション @@ -208,7 +213,9 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル") - parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)") + parser.add_argument( + "--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)" + ) parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument( "--max_data_loader_n_workers", @@ -231,10 +238,16 @@ def setup_parser() -> argparse.ArgumentParser: help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します", ) parser.add_argument( - "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" + "--bucket_no_upscale", + action="store_true", + help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します", ) parser.add_argument( - "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度" + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help="use mixed precision / 混合精度を使う場合、その精度", ) parser.add_argument( "--full_path", @@ -242,7 +255,15 @@ def setup_parser() -> argparse.ArgumentParser: help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", ) parser.add_argument( - "--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する" + "--flip_aug", + action="store_true", + help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する", + ) + parser.add_argument( + "--alpha_mask", + type=str, + default="", + help="save alpha mask for images for loss calculation / 損失計算用に画像のアルファマスクを保存する", ) parser.add_argument( "--skip_existing", diff --git a/library/config_util.py b/library/config_util.py index 82baab83..964270db 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -214,11 +214,13 @@ class ConfigSanitizer: DB_SUBSET_DISTINCT_SCHEMA = { Required("image_dir"): str, "is_reg": bool, + "alpha_mask": bool, } # FT means FineTuning FT_SUBSET_DISTINCT_SCHEMA = { Required("metadata_file"): str, "image_dir": str, + "alpha_mask": bool, } CN_SUBSET_ASCENDABLE_SCHEMA = { "caption_extension": str, diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index fad12740..af5813a1 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -479,14 +479,19 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): return noise -def apply_masked_loss(loss, mask_image): - # mask image is -1 to 1. we need to convert it to 0 to 1 - # mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel - mask_image = mask_image.to(dtype=loss.dtype) +def apply_masked_loss(loss, batch): + if "conditioning_images" in batch: + # conditioning image is -1 to 1. we need to convert it to 0 to 1 + mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel + mask_image = mask_image / 2 + 0.5 + elif "alpha_masks" in batch and batch["alpha_masks"] is not None: + # alpha mask is 0 to 1 + mask_image = batch["alpha_masks"].to(dtype=loss.dtype) + else: + return loss # resize to the same size as the loss mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area") - mask_image = mask_image / 2 + 0.5 loss = loss * mask_image return loss diff --git a/library/train_util.py b/library/train_util.py index 6cf28590..e7a50f04 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -159,9 +159,7 @@ class ImageInfo: self.text_encoder_outputs1: Optional[torch.Tensor] = None self.text_encoder_outputs2: Optional[torch.Tensor] = None self.text_encoder_pool2: Optional[torch.Tensor] = None - self.alpha_mask: Optional[torch.Tensor] = None - self.alpha_mask_flipped: Optional[torch.Tensor] = None - self.use_alpha_mask: bool = False + self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime class BucketManager: @@ -364,6 +362,7 @@ class BaseSubset: def __init__( self, image_dir: Optional[str], + alpha_mask: Optional[bool], num_repeats: int, shuffle_caption: bool, caption_separator: str, @@ -382,9 +381,9 @@ class BaseSubset: caption_suffix: Optional[str], token_warmup_min: int, token_warmup_step: Union[float, int], - alpha_mask: bool, ) -> None: self.image_dir = image_dir + self.alpha_mask = alpha_mask if alpha_mask is not None else False self.num_repeats = num_repeats self.shuffle_caption = shuffle_caption self.caption_separator = caption_separator @@ -407,8 +406,6 @@ class BaseSubset: self.img_count = 0 - self.alpha_mask = alpha_mask - class DreamBoothSubset(BaseSubset): def __init__( @@ -418,6 +415,7 @@ class DreamBoothSubset(BaseSubset): class_tokens: Optional[str], caption_extension: str, cache_info: bool, + alpha_mask: bool, num_repeats, shuffle_caption, caption_separator: str, @@ -441,6 +439,7 @@ class DreamBoothSubset(BaseSubset): super().__init__( image_dir, + alpha_mask, num_repeats, shuffle_caption, caption_separator, @@ -479,6 +478,7 @@ class FineTuningSubset(BaseSubset): self, image_dir, metadata_file: str, + alpha_mask: bool, num_repeats, shuffle_caption, caption_separator, @@ -502,6 +502,7 @@ class FineTuningSubset(BaseSubset): super().__init__( image_dir, + alpha_mask, num_repeats, shuffle_caption, caption_separator, @@ -921,7 +922,7 @@ class BaseDataset(torch.utils.data.Dataset): logger.info(f"mean ar error (without repeats): {mean_img_ar_error}") # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる - self.buckets_indices: List(BucketBatchIndex) = [] + self.buckets_indices: List[BucketBatchIndex] = [] for bucket_index, bucket in enumerate(self.bucket_manager.buckets): batch_count = int(math.ceil(len(bucket) / self.batch_size)) for batch_index in range(batch_count): @@ -991,8 +992,6 @@ class BaseDataset(torch.utils.data.Dataset): for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] - info.use_alpha_mask = subset.alpha_mask - if info.latents_npz is not None: # fine tuning dataset continue @@ -1002,7 +1001,9 @@ class BaseDataset(torch.utils.data.Dataset): if not is_main_process: # store to info only continue - cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug) + cache_available = is_disk_cached_latents_is_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) if cache_available: # do not add to batch continue @@ -1028,7 +1029,7 @@ class BaseDataset(torch.utils.data.Dataset): # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded logger.info("caching latents...") for batch in tqdm(batches, smoothing=1, total=len(batches)): - cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop) + cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する @@ -1202,18 +1203,15 @@ class BaseDataset(torch.utils.data.Dataset): alpha_mask = image_info.alpha_mask else: latents = image_info.latents_flipped - alpha_mask = image_info.alpha_mask_flipped + alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1]) image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk( - image_info.latents_npz - ) + latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz) if flipped: latents = flipped_latents - alpha_mask = flipped_alpha_mask + alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem del flipped_latents - del flipped_alpha_mask latents = torch.FloatTensor(latents) if alpha_mask is not None: alpha_mask = torch.FloatTensor(alpha_mask) @@ -1255,23 +1253,28 @@ class BaseDataset(torch.utils.data.Dataset): # augmentation aug = self.aug_helper.get_augmentor(subset.color_aug) if aug is not None: - img = aug(image=img)["image"] + # augment RGB channels only + img_rgb = img[:, :, :3] + img_rgb = aug(image=img_rgb)["image"] + img[:, :, :3] = img_rgb if flipped: img = img[:, ::-1, :].copy() # copy to avoid negative stride problem if subset.alpha_mask: if img.shape[2] == 4: - alpha_mask = img[:, :, 3] # [W,H] + alpha_mask = img[:, :, 3] # [H,W] + alpha_mask = transforms.ToTensor()(alpha_mask) # 0-255 -> 0-1 else: - alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H] - alpha_mask = transforms.ToTensor()(alpha_mask) + alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32) else: alpha_mask = None + img = img[:, :, :3] # remove alpha channel latents = None image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + del img images.append(image) latents_list.append(latents) @@ -1361,6 +1364,23 @@ class BaseDataset(torch.utils.data.Dataset): example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list) example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list) + # if one of alpha_masks is not None, we need to replace None with ones + none_or_not = [x is None for x in alpha_mask_list] + if all(none_or_not): + example["alpha_masks"] = None + elif any(none_or_not): + for i in range(len(alpha_mask_list)): + if alpha_mask_list[i] is None: + if images[i] is not None: + alpha_mask_list[i] = torch.ones((images[i].shape[1], images[i].shape[2]), dtype=torch.float32) + else: + alpha_mask_list[i] = torch.ones( + (latents_list[i].shape[1] * 8, latents_list[i].shape[2] * 8), dtype=torch.float32 + ) + example["alpha_masks"] = torch.stack(alpha_mask_list) + else: + example["alpha_masks"] = torch.stack(alpha_mask_list) + if images[0] is not None: images = torch.stack(images) images = images.to(memory_format=torch.contiguous_format).float() @@ -1378,8 +1398,6 @@ class BaseDataset(torch.utils.data.Dataset): example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions)) - example["alpha_mask"] = torch.stack(alpha_mask_list) if alpha_mask_list[0] is not None else None - if self.debug_dataset: example["image_keys"] = bucket[image_index : image_index + self.batch_size] return example @@ -1393,6 +1411,7 @@ class BaseDataset(torch.utils.data.Dataset): resized_sizes = [] bucket_reso = None flip_aug = None + alpha_mask = None random_crop = None for image_key in bucket[image_index : image_index + bucket_batch_size]: @@ -1401,10 +1420,13 @@ class BaseDataset(torch.utils.data.Dataset): if flip_aug is None: flip_aug = subset.flip_aug + alpha_mask = subset.alpha_mask random_crop = subset.random_crop bucket_reso = image_info.bucket_reso else: + # TODO そもそも混在してても動くようにしたほうがいい assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch" + assert alpha_mask == subset.alpha_mask, "alpha_mask must be same in a batch" assert random_crop == subset.random_crop, "random_crop must be same in a batch" assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch" @@ -1441,6 +1463,7 @@ class BaseDataset(torch.utils.data.Dataset): example["absolute_paths"] = absolute_paths example["resized_sizes"] = resized_sizes example["flip_aug"] = flip_aug + example["alpha_mask"] = alpha_mask example["random_crop"] = random_crop example["bucket_reso"] = bucket_reso return example @@ -2149,7 +2172,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset): dataset.disable_token_padding() -def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): +def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool): expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意 if not os.path.exists(npz_path): @@ -2167,6 +2190,12 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): return False if npz["latents_flipped"].shape[1:3] != expected_latents_size: return False + + if alpha_mask: + if "alpha_mask" not in npz: + return False + if npz["alpha_mask"].shape[0:2] != reso: # HxW + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -2177,14 +2206,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) def load_latents_from_disk( npz_path, -) -> Tuple[ - Optional[torch.Tensor], - Optional[List[int]], - Optional[List[int]], - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], -]: +) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: npz = np.load(npz_path) if "latents" not in npz: raise ValueError(f"error: npz is old format. please re-generate {npz_path}") @@ -2194,20 +2216,15 @@ def load_latents_from_disk( crop_ltrb = npz["crop_ltrb"].tolist() flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None - flipped_alpha_mask = npz["flipped_alpha_mask"] if "flipped_alpha_mask" in npz else None - return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask -def save_latents_to_disk( - npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None -): +def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None): kwargs = {} if flipped_latents_tensor is not None: kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() - if flipped_alpha_mask is not None: - kwargs["flipped_alpha_mask"] = flipped_alpha_mask.float().cpu().numpy() + kwargs["alpha_mask"] = alpha_mask # ndarray np.savez( npz_path, latents=latents_tensor.float().cpu().numpy(), @@ -2398,10 +2415,11 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: def load_image(image_path, alpha=False): image = Image.open(image_path) - if not image.mode == "RGB": - if alpha: + if alpha: + if not image.mode == "RGBA": image = image.convert("RGBA") - else: + else: + if not image.mode == "RGB": image = image.convert("RGB") img = np.array(image, np.uint8) return img @@ -2441,7 +2459,7 @@ def trim_and_resize_if_required( def cache_batch_latents( - vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool + vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool ) -> None: r""" requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz @@ -2453,49 +2471,43 @@ def cache_batch_latents( latents_original_size and latents_crop_ltrb are also set """ images = [] - alpha_masks = [] + alpha_masks: List[np.ndarray] = [] for info in image_infos: - image = load_image(info.absolute_path, info.use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) + image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) - if info.use_alpha_mask: - if image.shape[2] == 4: - alpha_mask = image[:, :, 3] # [W,H] - image = image[:, :, :3] - else: - alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H] - alpha_masks.append(transforms.ToTensor()(alpha_mask)) - image = IMAGE_TRANSFORMS(image) - images.append(image) info.latents_original_size = original_size info.latents_crop_ltrb = crop_ltrb + if use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 + else: + alpha_mask = np.ones_like(image[:, :, 0], dtype=np.float32) + else: + alpha_mask = None + alpha_masks.append(alpha_mask) + + image = image[:, :, :3] # remove alpha channel if exists + image = IMAGE_TRANSFORMS(image) + images.append(image) + img_tensors = torch.stack(images, dim=0) img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) with torch.no_grad(): latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") - if info.use_alpha_mask: - alpha_masks = torch.stack(alpha_masks, dim=0).to("cpu") - else: - alpha_masks = [None] * len(image_infos) - flipped_alpha_masks = [None] * len(image_infos) - if flip_aug: img_tensors = torch.flip(img_tensors, dims=[3]) with torch.no_grad(): flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") - if info.use_alpha_mask: - flipped_alpha_masks = torch.flip(alpha_masks, dims=[3]) else: flipped_latents = [None] * len(latents) - flipped_alpha_masks = [None] * len(image_infos) - for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip( - image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks - ): + for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks): # check NaN if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()): raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") @@ -2508,15 +2520,12 @@ def cache_batch_latents( info.latents_crop_ltrb, flipped_latent, alpha_mask, - flipped_alpha_mask, ) else: info.latents = latent if flip_aug: info.latents_flipped = flipped_latent - info.alpha_mask = alpha_mask - info.alpha_mask_flipped = flipped_alpha_mask if not HIGH_VRAM: clean_memory_on_device(vae.device) diff --git a/sdxl_train.py b/sdxl_train.py index dcd06766..9e20c60c 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -711,10 +711,8 @@ def train(args): loss = train_util.conditional_loss( noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) - if args.masked_loss: - loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) - if "alpha_mask" in batch and batch["alpha_mask"] is not None: - loss = apply_masked_loss(loss, batch["alpha_mask"]) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 347db27f..b7c88121 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -17,10 +17,13 @@ from library.config_util import ( BlueprintGenerator, ) from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + def cache_to_disk(args: argparse.Namespace) -> None: train_util.prepare_dataset_args(args, True) @@ -107,7 +110,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: else: _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える vae.set_use_memory_efficient_attention_xformers(args.xformers) vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) @@ -136,6 +139,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: b_size = len(batch["images"]) vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size flip_aug = batch["flip_aug"] + alpha_mask = batch["alpha_mask"] random_crop = batch["random_crop"] bucket_reso = batch["bucket_reso"] @@ -154,14 +158,16 @@ def cache_to_disk(args: argparse.Namespace) -> None: image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz" if args.skip_existing: - if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug): + if train_util.is_disk_cached_latents_is_expected( + image_info.bucket_reso, image_info.latents_npz, flip_aug, alpha_mask + ): logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") continue image_infos.append(image_info) if len(image_infos) > 0: - train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop) + train_util.cache_batch_latents(vae, True, image_infos, flip_aug, alpha_mask, random_crop) accelerator.wait_for_everyone() accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") diff --git a/train_db.py b/train_db.py index c4690000..39d8ea6e 100644 --- a/train_db.py +++ b/train_db.py @@ -359,10 +359,8 @@ def train(args): target = noise loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: - loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) - if "alpha_mask" in batch and batch["alpha_mask"] is not None: - loss = apply_masked_loss(loss, batch["alpha_mask"]) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_network.py b/train_network.py index cd1677ad..b272a6e1 100644 --- a/train_network.py +++ b/train_network.py @@ -774,7 +774,9 @@ class NetworkTrainer: if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "network_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -902,10 +904,8 @@ class NetworkTrainer: loss = train_util.conditional_loss( noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) - if args.masked_loss: - loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) - if "alpha_mask" in batch and batch["alpha_mask"] is not None: - loss = apply_masked_loss(loss, batch["alpha_mask"]) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_textual_inversion.py b/train_textual_inversion.py index a9c2a109..ade077c3 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -589,10 +589,8 @@ class TextualInversionTrainer: target = noise loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: - loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) - if "alpha_mask" in batch and batch["alpha_mask"] is not None: - loss = apply_masked_loss(loss, batch["alpha_mask"]) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 959839cb..efb59137 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -474,10 +474,8 @@ def train(args): target = noise loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: - loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) - if "alpha_mask" in batch and batch["alpha_mask"] is not None: - loss = apply_masked_loss(loss, batch["alpha_mask"]) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight From 00513b9b7066fc1307fbe26ad13ed39f3bceceb0 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 23 May 2024 22:27:12 -0400 Subject: [PATCH 051/356] Add LoRA+ LR Ratio info message to logger --- networks/dylora.py | 3 +++ networks/lora.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/networks/dylora.py b/networks/dylora.py index d57e3d58..b0925453 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -368,6 +368,9 @@ class DyLoRANetwork(torch.nn.Module): self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + def set_multiplier(self, multiplier): self.multiplier = multiplier for lora in self.text_encoder_loras + self.unet_loras: diff --git a/networks/lora.py b/networks/lora.py index 9f159f5d..82b8b5b4 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1134,6 +1134,9 @@ class LoRANetwork(torch.nn.Module): self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): # TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?) From e8cfd4ba1d4734c4dd37c9b5fdc0633378879d9b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 26 May 2024 22:01:37 +0900 Subject: [PATCH 052/356] fix to work cond mask and alpha mask --- library/config_util.py | 3 ++- library/custom_train_functions.py | 4 +++- library/train_util.py | 12 ++++++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 964270db..10b2457f 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -78,7 +78,6 @@ class BaseSubsetParams: caption_tag_dropout_rate: float = 0.0 token_warmup_min: int = 1 token_warmup_step: float = 0 - alpha_mask: bool = False @dataclass @@ -87,11 +86,13 @@ class DreamBoothSubsetParams(BaseSubsetParams): class_tokens: Optional[str] = None caption_extension: str = ".caption" cache_info: bool = False + alpha_mask: bool = False @dataclass class FineTuningSubsetParams(BaseSubsetParams): metadata_file: Optional[str] = None + alpha_mask: bool = False @dataclass diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index af5813a1..2a513dc5 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -484,9 +484,11 @@ def apply_masked_loss(loss, batch): # conditioning image is -1 to 1. we need to convert it to 0 to 1 mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel mask_image = mask_image / 2 + 0.5 + # print(f"conditioning_image: {mask_image.shape}") elif "alpha_masks" in batch and batch["alpha_masks"] is not None: # alpha mask is 0 to 1 - mask_image = batch["alpha_masks"].to(dtype=loss.dtype) + mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension + # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}") else: return loss diff --git a/library/train_util.py b/library/train_util.py index e7a50f04..1f9f3c5d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -561,6 +561,7 @@ class ControlNetSubset(BaseSubset): super().__init__( image_dir, + False, # alpha_mask num_repeats, shuffle_caption, caption_separator, @@ -1947,6 +1948,7 @@ class ControlNetDataset(BaseDataset): None, subset.caption_extension, subset.cache_info, + False, subset.num_repeats, subset.shuffle_caption, subset.caption_separator, @@ -2196,6 +2198,9 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph return False if npz["alpha_mask"].shape[0:2] != reso: # HxW return False + else: + if "alpha_mask" in npz: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -2296,6 +2301,13 @@ def debug_dataset(train_dataset, show_input_ids=False): if os.name == "nt": cv2.imshow("cond_img", cond_img) + if "alpha_masks" in example and example["alpha_masks"] is not None: + alpha_mask = example["alpha_masks"][j] + logger.info(f"alpha mask size: {alpha_mask.size()}") + alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8) + if os.name == "nt": + cv2.imshow("alpha_mask", alpha_mask) + if os.name == "nt": # only windows cv2.imshow("img", im) k = cv2.waitKey() From d50c1b3c5cfd590e43e832272a77bf8c84d371dd Mon Sep 17 00:00:00 2001 From: Dave Lage Date: Mon, 27 May 2024 01:11:01 -0400 Subject: [PATCH 053/356] Update issue link --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 838e4022..23e04935 100644 --- a/README.md +++ b/README.md @@ -237,7 +237,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290]( https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します。 -- データセット設定の .toml ファイルで、`caption_separator` が subset に指定できない不具合が修正されました。 PR [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) および [#1313](https://github.com/kohya-ss/sd-scripts/pull/1312) rockerBOO 氏に感謝します。 +- データセット設定の .toml ファイルで、`caption_separator` が subset に指定できない不具合が修正されました。 PR [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) および [#1313](https://github.com/kohya-ss/sd-scripts/pull/1313) rockerBOO 氏に感謝します。 - ControlNet-LLLite 学習時の潜在バグが修正されました。 PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) aria1th 氏に感謝します。 From a4c3155148e667f5235c2e3df52bad7fd8f95dc4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 May 2024 20:59:40 +0900 Subject: [PATCH 054/356] add doc for mask loss --- docs/masked_loss_README-ja.md | 40 +++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 docs/masked_loss_README-ja.md diff --git a/docs/masked_loss_README-ja.md b/docs/masked_loss_README-ja.md new file mode 100644 index 00000000..86053224 --- /dev/null +++ b/docs/masked_loss_README-ja.md @@ -0,0 +1,40 @@ +## マスクロスについて + +マスクロスは、入力画像のマスクで指定された部分だけ損失計算することで、画像の一部分だけを学習することができる機能です。 +たとえばキャラクタを学習したい場合、キャラクタ部分だけをマスクして学習することで、背景を無視して学習することができます。 + +マスクロスのマスクには、二種類の指定方法があります。 + +- マスク画像を用いる方法 +- 透明度(アルファチャネル)を使用する方法 + +なお、サンプルは [ずんずんPJイラスト/3Dデータ](https://zunko.jp/con_illust.html) の「AI画像モデル用学習データ」を使用しています。 + +### マスク画像を用いる方法 + +学習画像それぞれに対応するマスク画像を用意する方法です。学習画像と同じファイル名のマスク画像を用意し、それを学習画像と別のディレクトリに保存します。 + +マスク画像は、学習画像と同じサイズで、学習する部分を白、無視する部分を黒で描画します。グレースケールにも対応しています(127 ならロス重みが 0.5 になります)。なお、正確にはマスク画像の R チャネルが用いられます。 + +DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにマスク画像を保存するしてください。ControlNet のデータセットと同じですので、詳細は [ControlNet-LLLite](train_lllite_README-ja.md#データセットの準備) を参照してください。 + +### 透明度(アルファチャネル)を使用する方法 + +学習画像の透明度(アルファチャネル)がマスクとして使用されます。透明度が 0 の部分は無視され、255 の部分は学習されます。半透明の場合は、その透明度に応じてロス重みが変化します(127 ならおおむね 0.5)。 + +学習時のスクリプトのオプション `--alpha_mask`、または dataset の設定ファイルの subset で、`alpha_mask` を指定してください。たとえば、以下のようになります。 + +```toml +[[datasets.subsets]] +image_dir = "/path/to/image/dir" +caption_extension = ".txt" +num_repeats = 8 +alpha_mask = true +``` + +## 学習時の注意事項 + +- 現時点では DreamBooth 方式の dataset のみ対応しています。 +- マスクは latents のサイズ、つまり 1/8 に縮小されてから適用されます。そのため、細かい部分(たとえばアホ毛やイヤリングなど)はうまく学習できない可能性があります。マスクをわずかに拡張するなどの工夫が必要かもしれません。 +- マスクロスを用いる場合、学習対象外の部分をキャプションに含める必要はないかもしれません。(要検証) +- `alpha_mask` の場合、マスクの有無を切り替えると latents キャッシュが自動的に再生成されます。 From 71ad3c0f45ba64bd5dc069addc8ef0fa94bf4e19 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Mon, 27 May 2024 21:07:57 +0900 Subject: [PATCH 055/356] Update masked_loss_README-ja.md add sample images --- docs/masked_loss_README-ja.md | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/masked_loss_README-ja.md b/docs/masked_loss_README-ja.md index 86053224..5377a5af 100644 --- a/docs/masked_loss_README-ja.md +++ b/docs/masked_loss_README-ja.md @@ -14,6 +14,11 @@ 学習画像それぞれに対応するマスク画像を用意する方法です。学習画像と同じファイル名のマスク画像を用意し、それを学習画像と別のディレクトリに保存します。 +- 学習画像 + ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/607c5116-5f62-47de-8b66-9c4a597f0441) +- マスク画像 + ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/53e9b0f8-a4bf-49ed-882d-4026f84e8450) + マスク画像は、学習画像と同じサイズで、学習する部分を白、無視する部分を黒で描画します。グレースケールにも対応しています(127 ならロス重みが 0.5 になります)。なお、正確にはマスク画像の R チャネルが用いられます。 DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにマスク画像を保存するしてください。ControlNet のデータセットと同じですので、詳細は [ControlNet-LLLite](train_lllite_README-ja.md#データセットの準備) を参照してください。 @@ -22,7 +27,11 @@ DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディ 学習画像の透明度(アルファチャネル)がマスクとして使用されます。透明度が 0 の部分は無視され、255 の部分は学習されます。半透明の場合は、その透明度に応じてロス重みが変化します(127 ならおおむね 0.5)。 -学習時のスクリプトのオプション `--alpha_mask`、または dataset の設定ファイルの subset で、`alpha_mask` を指定してください。たとえば、以下のようになります。 +![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/0baa129b-446a-4aac-b98c-7208efb0e75e) + +※それぞれの画像は透過PNG + +学習時のスクリプトのオプションに `--alpha_mask` を指定するか、dataset の設定ファイルの subset で、`alpha_mask` を指定してください。たとえば、以下のようになります。 ```toml [[datasets.subsets]] From fc85496f7e99b2bbbbd0246e0b0521780c55d859 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 May 2024 21:25:06 +0900 Subject: [PATCH 056/356] update docs for masked loss --- README.md | 8 +++++ docs/masked_loss_README-ja.md | 10 ++++++- docs/masked_loss_README.md | 56 +++++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 docs/masked_loss_README.md diff --git a/README.md b/README.md index 23e04935..52c96339 100644 --- a/README.md +++ b/README.md @@ -161,6 +161,10 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Example: `--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` or `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` etc. - `network_module` `networks.lora` and `networks.dylora` are available. +- The feature to use the transparency (alpha channel) of the image as a mask in the loss calculation has been added. PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) Thanks to u-haru! + - The transparent part is ignored during training. Specify the `--alpha_mask` option in the training script or specify `alpha_mask = true` in the dataset configuration file. + - See [About masked loss](./docs/masked_loss_README.md) for details. + - LoRA training in SDXL now supports block-wise learning rates and block-wise dim (rank). PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331) - Specify the learning rate and dim (rank) for each block. - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only). @@ -214,6 +218,10 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - 例:`--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` または `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` など - `network_module` の `networks.lora` および `networks.dylora` で使用可能です。 +- 画像の透明度(アルファチャネル)をロス計算時のマスクとして使用する機能が追加されました。PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) u-haru 氏に感謝します。 + - 透明部分が学習時に無視されるようになります。学習スクリプトに `--alpha_mask` オプションを指定するか、データセット設定ファイルに `alpha_mask = true` を指定してください。 + - 詳細は [マスクロスについて](./docs/masked_loss_README-ja.md) をご覧ください。 + - SDXL の LoRA で階層別学習率、階層別 dim (rank) をサポートしました。PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331) - ブロックごとに学習率および dim (rank) を指定することができます。 - 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。 diff --git a/docs/masked_loss_README-ja.md b/docs/masked_loss_README-ja.md index 5377a5af..58f042c3 100644 --- a/docs/masked_loss_README-ja.md +++ b/docs/masked_loss_README-ja.md @@ -19,9 +19,17 @@ - マスク画像 ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/53e9b0f8-a4bf-49ed-882d-4026f84e8450) +```.toml +[[datasets.subsets]] +image_dir = "/path/to/a_zundamon" +caption_extension = ".txt" +conditioning_data_dir = "/path/to/a_zundamon_mask" +num_repeats = 8 +``` + マスク画像は、学習画像と同じサイズで、学習する部分を白、無視する部分を黒で描画します。グレースケールにも対応しています(127 ならロス重みが 0.5 になります)。なお、正確にはマスク画像の R チャネルが用いられます。 -DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにマスク画像を保存するしてください。ControlNet のデータセットと同じですので、詳細は [ControlNet-LLLite](train_lllite_README-ja.md#データセットの準備) を参照してください。 +DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにマスク画像を保存してください。ControlNet のデータセットと同じですので、詳細は [ControlNet-LLLite](train_lllite_README-ja.md#データセットの準備) を参照してください。 ### 透明度(アルファチャネル)を使用する方法 diff --git a/docs/masked_loss_README.md b/docs/masked_loss_README.md new file mode 100644 index 00000000..3ac5ad21 --- /dev/null +++ b/docs/masked_loss_README.md @@ -0,0 +1,56 @@ +## Masked Loss + +Masked loss is a feature that allows you to train only part of an image by calculating the loss only for the part specified by the mask of the input image. For example, if you want to train a character, you can train only the character part by masking it, ignoring the background. + +There are two ways to specify the mask for masked loss. + +- Using a mask image +- Using transparency (alpha channel) of the image + +The sample uses the "AI image model training data" from [ZunZunPJ Illustration/3D Data](https://zunko.jp/con_illust.html). + +### Using a mask image + +This is a method of preparing a mask image corresponding to each training image. Prepare a mask image with the same file name as the training image and save it in a different directory from the training image. + +- Training image + ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/607c5116-5f62-47de-8b66-9c4a597f0441) +- Mask image + ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/53e9b0f8-a4bf-49ed-882d-4026f84e8450) + +```.toml +[[datasets.subsets]] +image_dir = "/path/to/a_zundamon" +caption_extension = ".txt" +conditioning_data_dir = "/path/to/a_zundamon_mask" +num_repeats = 8 +``` + +The mask image is the same size as the training image, with the part to be trained drawn in white and the part to be ignored in black. It also supports grayscale (127 gives a loss weight of 0.5). The R channel of the mask image is used currently. + +Use the dataset in the DreamBooth method, and save the mask image in the directory specified by `conditioning_data_dir`. It is the same as the ControlNet dataset, so please refer to [ControlNet-LLLite](train_lllite_README.md#Preparing-the-dataset) for details. + +### Using transparency (alpha channel) of the image + +The transparency (alpha channel) of the training image is used as a mask. The part with transparency 0 is ignored, the part with transparency 255 is trained. For semi-transparent parts, the loss weight changes according to the transparency (127 gives a weight of about 0.5). + +![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/0baa129b-446a-4aac-b98c-7208efb0e75e) + +※Each image is a transparent PNG + +Specify `--alpha_mask` in the training script options or specify `alpha_mask` in the subset of the dataset configuration file. For example, it will look like this. + +```toml +[[datasets.subsets]] +image_dir = "/path/to/image/dir" +caption_extension = ".txt" +num_repeats = 8 +alpha_mask = true +``` + +## Notes on training + +- At the moment, only the dataset in the DreamBooth method is supported. +- The mask is applied after the size is reduced to 1/8, which is the size of the latents. Therefore, fine details (such as ahoge or earrings) may not be learned well. Some dilations of the mask may be necessary. +- If using masked loss, it may not be necessary to include parts that are not to be trained in the caption. (To be verified) +- In the case of `alpha_mask`, the latents cache is automatically regenerated when the enable/disable state of the mask is switched. From b2363f1021955c049c98e65676efca130690c40f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 31 May 2024 12:20:20 +0800 Subject: [PATCH 057/356] Final implementation --- library/train_util.py | 11 ++++- train_network.py | 104 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 106 insertions(+), 9 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 1f9f3c5d..beb33bf8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -657,8 +657,15 @@ class BaseDataset(torch.utils.data.Dataset): def set_current_epoch(self, epoch): if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする - self.shuffle_buckets() - self.current_epoch = epoch + if epoch > self.current_epoch: + logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) + num_epochs = epoch - self.current_epoch + for _ in range(num_epochs): + self.current_epoch += 1 + self.shuffle_buckets() + else: + logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) + self.current_epoch = epoch def set_current_step(self, step): self.current_step = step diff --git a/train_network.py b/train_network.py index b272a6e1..76e6cd8a 100644 --- a/train_network.py +++ b/train_network.py @@ -493,17 +493,24 @@ class NetworkTrainer: # before resuming make hook for saving/loading to save/load the network weights only def save_model_hook(models, weights, output_dir): # pop weights of other models than network to save only network weights - # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606 - if accelerator.is_main_process or args.deepspeed: + if accelerator.is_main_process: remove_indices = [] for i, model in enumerate(models): if not isinstance(model, type(accelerator.unwrap_model(network))): remove_indices.append(i) for i in reversed(remove_indices): - if len(weights) > i: - weights.pop(i) + weights.pop(i) # print(f"save model hook: {len(weights)} weights will be saved") + # save current ecpoch and step + train_state_file = os.path.join(output_dir, "train_state.json") + # +1 is needed because the state is saved before current_step is set from global_step + logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}") + with open(train_state_file, "w", encoding="utf-8") as f: + json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f) + + steps_from_state = None + def load_model_hook(models, input_dir): # remove models except network remove_indices = [] @@ -514,6 +521,15 @@ class NetworkTrainer: models.pop(i) # print(f"load model hook: {len(models)} models will be loaded") + # load current epoch and step to + nonlocal steps_from_state + train_state_file = os.path.join(input_dir, "train_state.json") + if os.path.exists(train_state_file): + with open(train_state_file, "r", encoding="utf-8") as f: + data = json.load(f) + steps_from_state = data["current_step"] + logger.info(f"load train state from {train_state_file}: {data}") + accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -757,7 +773,53 @@ class NetworkTrainer: if key in metadata: minimum_metadata[key] = metadata[key] - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + # calculate steps to skip when resuming or starting from a specific step + initial_step = 0 + if args.initial_epoch is not None or args.initial_step is not None: + # if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming + if steps_from_state is not None: + logger.warning( + "steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます" + ) + if args.initial_step is not None: + initial_step = args.initial_step + else: + # num steps per epoch is calculated by num_processes and gradient_accumulation_steps + initial_step = (args.initial_epoch - 1) * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + else: + # if initial_epoch and initial_step are not specified, steps_from_state is used when resuming + if steps_from_state is not None: + initial_step = steps_from_state + steps_from_state = None + + if initial_step > 0: + assert ( + args.max_train_steps > initial_step + ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}" + + progress_bar = tqdm( + range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" + ) + + epoch_to_start = 0 + if initial_step > 0: + if args.skip_until_initial_step: + # if skip_until_initial_step is specified, load data and discard it to ensure the same data is used + if not args.resume: + logger.info( + f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります" + ) + logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします") + initial_step *= args.gradient_accumulation_steps + else: + # if not, only epoch no is skipped for informative purpose + epoch_to_start = initial_step // math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) + initial_step = 0 # do not skip + global_step = 0 noise_scheduler = DDPMScheduler( @@ -816,7 +878,11 @@ class NetworkTrainer: self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # training loop - for epoch in range(num_train_epochs): + for skip_epoch in range(epoch_to_start): # skip epochs + logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") + initial_step -= len(train_dataloader) + + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -824,7 +890,12 @@ class NetworkTrainer: accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) - for step, batch in enumerate(train_dataloader): + skipped_dataloader = None + if initial_step > 0: + skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step-1) + initial_step = 1 + + for step, batch in enumerate(skipped_dataloader or train_dataloader): current_step.value = global_step with accelerator.accumulate(training_model): on_step_start(text_encoder, unet) @@ -1126,6 +1197,25 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + parser.add_argument( + "--skip_until_initial_step", + action="store_true", + help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする", + ) + parser.add_argument( + "--initial_epoch", + type=int, + default=None, + help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`." + + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる", + ) + parser.add_argument( + "--initial_step", + type=int, + default=None, + help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch." + + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする", + ) # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") From 3eb27ced52e8bf522c7e490c3dacba1f8597f5b1 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 31 May 2024 12:24:15 +0800 Subject: [PATCH 058/356] Skip the final 1 step --- train_network.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/train_network.py b/train_network.py index 76e6cd8a..d1f02d53 100644 --- a/train_network.py +++ b/train_network.py @@ -897,6 +897,10 @@ class NetworkTrainer: for step, batch in enumerate(skipped_dataloader or train_dataloader): current_step.value = global_step + if initial_step > 0: + initial_step -= 1 + continue + with accelerator.accumulate(training_model): on_step_start(text_encoder, unet) From e5bab69e3a8f3dc4afb1badba65b6c50ca2f36d8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 2 Jun 2024 21:11:40 +0900 Subject: [PATCH 059/356] fix alpha mask without disk cache closes #1351, ref #1339 --- library/train_util.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 1f9f3c5d..566f5927 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1265,7 +1265,8 @@ class BaseDataset(torch.utils.data.Dataset): if subset.alpha_mask: if img.shape[2] == 4: alpha_mask = img[:, :, 3] # [H,W] - alpha_mask = transforms.ToTensor()(alpha_mask) # 0-255 -> 0-1 + alpha_mask = alpha_mask.astype(np.float32) / 255.0 # 0.0~1.0 + alpha_mask = torch.FloatTensor(alpha_mask) else: alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32) else: @@ -2211,7 +2212,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) def load_latents_from_disk( npz_path, -) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: +) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: npz = np.load(npz_path) if "latents" not in npz: raise ValueError(f"error: npz is old format. please re-generate {npz_path}") @@ -2229,7 +2230,7 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, fli if flipped_latents_tensor is not None: kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask # ndarray + kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() np.savez( npz_path, latents=latents_tensor.float().cpu().numpy(), @@ -2496,8 +2497,9 @@ def cache_batch_latents( if image.shape[2] == 4: alpha_mask = image[:, :, 3] # [H,W] alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] else: - alpha_mask = np.ones_like(image[:, :, 0], dtype=np.float32) + alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] else: alpha_mask = None alpha_masks.append(alpha_mask) From 4dbcef429b744d0cc101494802448b8c15f4f674 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 4 Jun 2024 21:26:55 +0900 Subject: [PATCH 060/356] update for corner cases --- library/train_util.py | 3 +++ train_network.py | 23 ++++++++++++++--------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 102f9f03..4736ff4f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -663,6 +663,7 @@ class BaseDataset(torch.utils.data.Dataset): for _ in range(num_epochs): self.current_epoch += 1 self.shuffle_buckets() + # self.current_epoch seem to be set to 0 again in the next epoch. it may be caused by skipped_dataloader? else: logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) self.current_epoch = epoch @@ -5560,6 +5561,8 @@ class LossRecorder: if epoch == 0: self.loss_list.append(loss) else: + while len(self.loss_list) <= step: + self.loss_list.append(0.0) self.loss_total -= self.loss_list[step] self.loss_list[step] = loss self.loss_total += loss diff --git a/train_network.py b/train_network.py index d1f02d53..7ba07385 100644 --- a/train_network.py +++ b/train_network.py @@ -493,13 +493,15 @@ class NetworkTrainer: # before resuming make hook for saving/loading to save/load the network weights only def save_model_hook(models, weights, output_dir): # pop weights of other models than network to save only network weights - if accelerator.is_main_process: + # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606 + if accelerator.is_main_process or args.deepspeed: remove_indices = [] for i, model in enumerate(models): if not isinstance(model, type(accelerator.unwrap_model(network))): remove_indices.append(i) for i in reversed(remove_indices): - weights.pop(i) + if len(weights) > i: + weights.pop(i) # print(f"save model hook: {len(weights)} weights will be saved") # save current ecpoch and step @@ -813,11 +815,12 @@ class NetworkTrainer: ) logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします") initial_step *= args.gradient_accumulation_steps + + # set epoch to start to make initial_step less than len(train_dataloader) + epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) else: # if not, only epoch no is skipped for informative purpose - epoch_to_start = initial_step // math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) initial_step = 0 # do not skip global_step = 0 @@ -878,9 +881,11 @@ class NetworkTrainer: self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # training loop - for skip_epoch in range(epoch_to_start): # skip epochs - logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") - initial_step -= len(train_dataloader) + if initial_step > 0: # only if skip_until_initial_step is specified + for skip_epoch in range(epoch_to_start): # skip epochs + logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") + initial_step -= len(train_dataloader) + global_step = initial_step for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") @@ -892,7 +897,7 @@ class NetworkTrainer: skipped_dataloader = None if initial_step > 0: - skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step-1) + skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1) initial_step = 1 for step, batch in enumerate(skipped_dataloader or train_dataloader): From 4ecbac131aba3d121f9708b3ac2a1f4726b17dc0 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Wed, 5 Jun 2024 16:31:44 +0900 Subject: [PATCH 061/356] Bump crate-ci/typos from 1.19.0 to 1.21.0, fix typos, and updated _typos.toml (Close #1307) --- .github/workflows/typos.yml | 2 +- _typos.toml | 2 ++ library/ipex/attention.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index e8b06483..c81ff321 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -18,4 +18,4 @@ jobs: - uses: actions/checkout@v4 - name: typos-action - uses: crate-ci/typos@v1.19.0 + uses: crate-ci/typos@v1.21.0 diff --git a/_typos.toml b/_typos.toml index ae9e06b1..bbf7728f 100644 --- a/_typos.toml +++ b/_typos.toml @@ -2,6 +2,7 @@ # Instruction: https://github.com/marketplace/actions/typos-action#getting-started [default.extend-identifiers] +ddPn08="ddPn08" [default.extend-words] NIN="NIN" @@ -27,6 +28,7 @@ rik="rik" koo="koo" yos="yos" wn="wn" +hime="hime" [files] diff --git a/library/ipex/attention.py b/library/ipex/attention.py index d989ad53..2bc62f65 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -5,7 +5,7 @@ from functools import cache # pylint: disable=protected-access, missing-function-docstring, line-too-long -# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers +# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4)) attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) From 58fb64819ab117e2b7bca6e87bae28901b616860 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Jun 2024 19:26:09 +0900 Subject: [PATCH 062/356] set static graph flag when DDP ref #1363 --- sdxl_train_control_net_lllite.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 30131090..5ff060a9 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -289,6 +289,9 @@ def train(args): # acceleratorがなんかよろしくやってくれるらしい unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + if isinstance(unet, DDP): + unet._set_static_graph() # avoid error for multiple use of the parameter + if args.gradient_checkpointing: unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる else: From 1a104dc75ee5733af8ba17cc9778b39e26673734 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Jun 2024 19:26:36 +0900 Subject: [PATCH 063/356] make forward/backward pathes same ref #1363 --- networks/control_net_lllite_for_train.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/networks/control_net_lllite_for_train.py b/networks/control_net_lllite_for_train.py index 65b3520c..366451b7 100644 --- a/networks/control_net_lllite_for_train.py +++ b/networks/control_net_lllite_for_train.py @@ -7,8 +7,10 @@ from typing import Optional, List, Type import torch from library import sdxl_original_unet from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) # input_blocksに適用するかどうか / if True, input_blocks are not applied @@ -103,19 +105,15 @@ class LLLiteLinear(ORIGINAL_LINEAR): add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim) self.cond_image = None - self.cond_emb = None def set_cond_image(self, cond_image): self.cond_image = cond_image - self.cond_emb = None def forward(self, x): if not self.enabled: return super().forward(x) - if self.cond_emb is None: - self.cond_emb = self.lllite_conditioning1(self.cond_image) - cx = self.cond_emb + cx = self.lllite_conditioning1(self.cond_image) # make forward and backward compatible # reshape / b,c,h,w -> b,h*w,c n, c, h, w = cx.shape @@ -159,9 +157,7 @@ class LLLiteConv2d(ORIGINAL_CONV2D): if not self.enabled: return super().forward(x) - if self.cond_emb is None: - self.cond_emb = self.lllite_conditioning1(self.cond_image) - cx = self.cond_emb + cx = self.lllite_conditioning1(self.cond_image) cx = torch.cat([cx, self.down(x)], dim=1) cx = self.mid(cx) From 18d7597b0b39cc2204dfbdfdcbf0fead97414be1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Jun 2024 19:51:30 +0900 Subject: [PATCH 064/356] update README --- README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.md b/README.md index 52c96339..25aba639 100644 --- a/README.md +++ b/README.md @@ -178,6 +178,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds! +- `train_network.py` and `sdxl_train_network.py` now restore the order/position of data loading from DataSet when resuming training. PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) Thanks to KohakuBlueleaf! + - This resolves the issue where the order of data loading from DataSet changes when resuming training. + - Specify the `--skip_until_initial_step` option to skip data loading until the specified step. If not specified, data loading starts from the beginning of the DataSet (same as before). + - If `--resume` is specified, the step saved in the state is used. + - Specify the `--initial_step` or `--initial_epoch` option to skip data loading until the specified step or epoch. Use these options in conjunction with `--skip_until_initial_step`. These options can be used without `--resume` (use them when resuming training with `--network_weights`). + - An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra! - It seems that the model file loading is faster in the WSL environment etc. - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`. @@ -235,6 +241,12 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - SD1.5/2.x 用の ControlNet 学習スクリプト `train_controlnet.py` が動作しなくなっていたのが修正されました。PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) sdbds 氏に感謝します。 +- `train_network.py` および `sdxl_train_network.py` で、学習再開時に DataSet の読み込み順についても復元できるようになりました。PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) KohakuBlueleaf 氏に感謝します。 + - これにより、学習再開時に DataSet の読み込み順が変わってしまう問題が解消されます。 + - `--skip_until_initial_step` オプションを指定すると、指定したステップまで DataSet 読み込みをスキップします。指定しない場合の動作は変わりません(DataSet の最初から読み込みます) + - `--resume` オプションを指定すると、state に保存されたステップ数が使用されます。 + - `--initial_step` または `--initial_epoch` オプションを指定すると、指定したステップまたはエポックまで DataSet 読み込みをスキップします。これらのオプションは `--skip_until_initial_step` と併用してください。またこれらのオプションは `--resume` と併用しなくても使えます(`--network_weights` を用いた学習再開時などにお使いください )。 + - SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。 - WSL 環境等でモデルファイルの読み込みが高速化されるようです。 - `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。 From 56bb81c9e6483b8b4d5b83639548855b8359f4b4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 12 Jun 2024 21:39:35 +0900 Subject: [PATCH 065/356] add grad_hook after restore state closes #1344 --- sdxl_train.py | 46 +++++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index 9e20c60c..ae92d6a3 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -481,6 +481,26 @@ def train(args): text_encoder2 = accelerator.prepare(text_encoder2) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + # TextEncoderの出力をキャッシュするときにはCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + text_encoder1.to("cpu", dtype=torch.float32) + text_encoder2.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + if args.fused_backward_pass: # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused @@ -532,26 +552,6 @@ def train(args): parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 - # TextEncoderの出力をキャッシュするときにはCPUへ移動する - if args.cache_text_encoder_outputs: - # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 - text_encoder1.to("cpu", dtype=torch.float32) - text_encoder2.to("cpu", dtype=torch.float32) - clean_memory_on_device(accelerator.device) - else: - # make sure Text Encoders are on GPU - text_encoder1.to(accelerator.device) - text_encoder2.to(accelerator.device) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. - # -> But we think it's ok to patch accelerator even if deepspeed is enabled. - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) @@ -589,7 +589,11 @@ def train(args): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) # For --sample_at_first sdxl_train_util.sample_images( From e5268286bf90ddcc53ad1deb31aba857cfa967d5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 15 Jun 2024 22:20:24 +0900 Subject: [PATCH 066/356] add sd3 models and inference script --- library/sd3_models.py | 1796 ++++++++++++++++++++++++++++++++++++++ library/sd3_utils.py | 113 +++ sd3_minimal_inference.py | 347 ++++++++ 3 files changed, 2256 insertions(+) create mode 100644 library/sd3_models.py create mode 100644 library/sd3_utils.py create mode 100644 sd3_minimal_inference.py diff --git a/library/sd3_models.py b/library/sd3_models.py new file mode 100644 index 00000000..294a69b0 --- /dev/null +++ b/library/sd3_models.py @@ -0,0 +1,1796 @@ +# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref +# the original code is licensed under the MIT License + +# and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution! + +from functools import partial +import math +from typing import Dict, Optional +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint +from transformers import CLIPTokenizer, T5TokenizerFast + + +memory_efficient_attention = None +try: + import xformers +except: + pass + +try: + from xformers.ops import memory_efficient_attention +except: + memory_efficient_attention = None + + +# region tokenizer +class SDTokenizer: + def __init__( + self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None + ): + """ + サブクラスで各種の設定を行ってる。このクラスはその設定に基づき重み付きのトークン化を行うようだ。 + Some settings are done in subclasses. This class seems to perform tokenization with weights based on those settings. + """ + self.tokenizer = tokenizer + self.max_length = max_length + self.min_length = min_length + empty = self.tokenizer("")["input_ids"] + if has_start_token: + self.tokens_start = 1 + self.start_token = empty[0] + self.end_token = empty[1] + else: + self.tokens_start = 0 + self.start_token = None + self.end_token = empty[0] + self.pad_with_end = pad_with_end + self.pad_to_max_length = pad_to_max_length + vocab = self.tokenizer.get_vocab() + self.inv_vocab = {v: k for k, v in vocab.items()} + self.max_word_length = 8 + + def tokenize_with_weights(self, text: str): + """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. + The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" + """ + ja: テキストをトークン化し、重み値を持ちます - すべての値に1.0を仮定し、他の機能を無視します。 + 詳細は参考実装には関係なく、重み自体はSD3に対して弱い影響しかありません。へぇ~ + """ + if self.pad_with_end: + pad_token = self.end_token + else: + pad_token = 0 + batch = [] + if self.start_token is not None: + batch.append((self.start_token, 1.0)) + to_tokenize = text.replace("\n", " ").split(" ") + to_tokenize = [x for x in to_tokenize if x != ""] + for word in to_tokenize: + batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]]) + batch.append((self.end_token, 1.0)) + if self.pad_to_max_length: + batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) + if self.min_length is not None and len(batch) < self.min_length: + batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) + return [batch] + + +class T5XXLTokenizer(SDTokenizer): + """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" + + def __init__(self): + super().__init__( + pad_with_end=False, + tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), + has_start_token=False, + pad_to_max_length=False, + max_length=99999999, + min_length=77, + ) + + +class SDXLClipGTokenizer(SDTokenizer): + def __init__(self, tokenizer): + super().__init__(pad_with_end=False, tokenizer=tokenizer) + + +class SD3Tokenizer: + def __init__(self, t5xxl=True): + # TODO cache tokenizer settings locally or hold them in the repo like ComfyUI + clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) + self.clip_g = SDXLClipGTokenizer(clip_tokenizer) + self.t5xxl = T5XXLTokenizer() if t5xxl else None + + def tokenize_with_weights(self, text: str): + return ( + self.clip_l.tokenize_with_weights(text), + self.clip_g.tokenize_with_weights(text), + self.t5xxl.tokenize_with_weights(text) if self.t5xxl is not None else None, + ) + + +# endregion + +# region mmdit + + +def get_2d_sincos_pos_embed( + embed_dim, + grid_size, + scaling_factor=None, + offset=None, +): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + if scaling_factor is not None: + grid = grid / scaling_factor + if offset is not None: + grid = grid - offset + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid_torch( + embed_dim, + pos, + device=None, + dtype=torch.float32, +): + omega = torch.arange(embed_dim // 2, device=device, dtype=dtype) + omega *= 2.0 / embed_dim + omega = 1.0 / 10000**omega + out = torch.outer(pos.reshape(-1), omega) + emb = torch.cat([out.sin(), out.cos()], dim=1) + return emb + + +def get_2d_sincos_pos_embed_torch( + embed_dim, + w, + h, + val_center=7.5, + val_magnitude=7.5, + device=None, + dtype=torch.float32, +): + small = min(h, w) + val_h = (h / small) * val_magnitude + val_w = (w / small) * val_magnitude + grid_h, grid_w = torch.meshgrid( + torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype), + torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype), + indexing="ij", + ) + emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype) + emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype) + emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D) + return emb + + +def modulate(x, shift, scale): + if shift is None: + shift = torch.zeros_like(scale) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def default(x, default_value): + if x is None: + return default_value + return x + + +def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + # device=t.device, dtype=t.dtype + # ) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(dtype=t.dtype) + return embedding + + +def rmsnorm(x, eps=1e-6): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +class PatchEmbed(nn.Module): + def __init__( + self, + img_size=256, + patch_size=4, + in_channels=3, + embed_dim=512, + norm_layer=None, + flatten=True, + bias=True, + strict_img_size=True, + dynamic_img_pad=True, + ): + super().__init__() + self.patch_size = patch_size + self.flatten = flatten + self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad + if img_size is not None: + self.img_size = img_size + self.grid_size = img_size // patch_size + self.num_patches = self.grid_size**2 + else: + self.img_size = None + self.grid_size = None + self.num_patches = None + + self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias) + self.norm = nn.Identity() if norm_layer is None else norm_layer(embed_dim) + + def forward(self, x): + B, C, H, W = x.shape + + if self.dynamic_img_pad: + # Pad input so we won't have partial patch + pad_h = (self.patch_size - H % self.patch_size) % self.patch_size + pad_w = (self.patch_size - W % self.patch_size) % self.patch_size + x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="reflect") + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +# FinalLayer in mmdit.py +class UnPatch(nn.Module): + def __init__(self, hidden_size=512, patch_size=4, out_channels=3): + super().__init__() + self.patch_size = patch_size + self.c = out_channels + + # eps is default in mmdit.py + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size), + ) + + def forward(self, x: torch.Tensor, cmod, H=None, W=None): + b, n, _ = x.shape + p = self.patch_size + c = self.c + if H is None and W is None: + w = h = int(n**0.5) + assert h * w == n + else: + h = H // p if H else n // (W // p) + w = W // p if W else n // h + assert h * w == n + + shift, scale = self.adaLN_modulation(cmod).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + + x = x.view(b, h, w, p, p, c) + x = x.permute(0, 5, 1, 3, 2, 4).contiguous() + x = x.view(b, c, h * p, w * p) + return x + + +class MLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=lambda: nn.GELU(), + norm_layer=None, + bias=True, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.use_conv = use_conv + + layer = partial(nn.Conv1d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = layer(in_features, hidden_features, bias=bias) + self.fc2 = layer(hidden_features, out_features, bias=bias) + self.act = act_layer() + self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.norm(x) + x = self.fc2(x) + return x + + +class TimestepEmbedding(nn.Module): + def __init__(self, hidden_size, freq_embed_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(freq_embed_size, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + self.freq_embed_size = freq_embed_size + + def forward(self, t, dtype=None, **kwargs): + t_freq = timestep_embedding(t, self.freq_embed_size).to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class Embedder(nn.Module): + def __init__(self, input_dim, hidden_size): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(input_dim, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + + def forward(self, x): + return self.mlp(x) + + +class RMSNorm(torch.nn.Module): + def __init__( + self, + dim: int, + elementwise_affine: bool = False, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + """ + super().__init__() + self.eps = eps + self.learnable_scale = elementwise_affine + if self.learnable_scale: + self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + else: + self.register_parameter("weight", None) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + """ + x = rmsnorm(x, eps=self.eps) + if self.learnable_scale: + return x * self.weight.to(device=x.device, dtype=x.dtype) + else: + return x + + +class SwiGLUFeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: float = None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +# Linears for SelfAttention in mmdit.py +class AttentionLinears(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + pre_only: bool = False, + qk_norm: str = None, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + if not pre_only: + self.proj = nn.Linear(dim, dim) + self.pre_only = pre_only + + if qk_norm == "rms": + self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + elif qk_norm == "ln": + self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + elif qk_norm is None: + self.ln_q = nn.Identity() + self.ln_k = nn.Identity() + else: + raise ValueError(qk_norm) + + def pre_attention(self, x: torch.Tensor) -> torch.Tensor: + """ + output: + q, k, v: [B, L, D] + """ + B, L, C = x.shape + qkv: torch.Tensor = self.qkv(x) + q, k, v = qkv.reshape(B, L, -1, self.head_dim).chunk(3, dim=2) + q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1) + k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1) + return (q, k, v) + + def post_attention(self, x: torch.Tensor) -> torch.Tensor: + assert not self.pre_only + x = self.proj(x) + return x + + +MEMORY_LAYOUTS = { + "torch": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2), + lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1), + lambda x: (1, x, 1, 1), + ), + "xformers": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim), + lambda x: x.reshape(x.shape[0], x.shape[1], -1), + lambda x: (1, 1, x, 1), + ), + "math": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2), + lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1), + lambda x: (1, x, 1, 1), + ), +} +# ATTN_FUNCTION = { +# "torch": F.scaled_dot_product_attention, +# "xformers": memory_efficient_attention, +# } + + +def vanilla_attention(q, k, v, mask, scale=None): + if scale is None: + scale = math.sqrt(q.size(-1)) + scores = torch.bmm(q, k.transpose(-1, -2)) / scale + if mask is not None: + mask = einops.rearrange(mask, "b ... -> b (...)") + max_neg_value = -torch.finfo(scores.dtype).max + mask = einops.repeat(mask, "b j -> (b h) j", h=q.size(-3)) + scores = scores.masked_fill(~mask, max_neg_value) + p_attn = F.softmax(scores, dim=-1) + return torch.bmm(p_attn, v) + + +def attention(q, k, v, head_dim, mask=None, scale=None, mode="xformers"): + """ + q, k, v: [B, L, D] + """ + pre_attn_layout = MEMORY_LAYOUTS[mode][0] + post_attn_layout = MEMORY_LAYOUTS[mode][1] + q = pre_attn_layout(q, head_dim) + k = pre_attn_layout(k, head_dim) + v = pre_attn_layout(v, head_dim) + + # scores = ATTN_FUNCTION[mode](q, k.to(q), v.to(q), mask, scale=scale) + if mode == "torch": + assert scale is None + scores = F.scaled_dot_product_attention(q, k.to(q), v.to(q), mask) # , scale=scale) + elif mode == "xformers": + scores = memory_efficient_attention(q, k.to(q), v.to(q), mask, scale=scale) + else: + scores = vanilla_attention(q, k.to(q), v.to(q), mask, scale=scale) + + scores = post_attn_layout(scores) + return scores + + +class SelfAttention(AttentionLinears): + def __init__(self, dim, num_heads=8, mode="xformers"): + super().__init__(dim, num_heads, qkv_bias=True, pre_only=False) + assert mode in MEMORY_LAYOUTS + self.head_dim = dim // num_heads + self.attn_mode = mode + + def set_attn_mode(self, mode): + self.attn_mode = mode + + def forward(self, x): + q, k, v = self.pre_attention(x) + attn_score = attention(q, k, v, self.head_dim, mode=self.attn_mode) + return self.post_attention(attn_score) + + +class TransformerBlock(nn.Module): + def __init__(self, context_size, mode="xformers"): + super().__init__() + self.context_size = context_size + self.norm1 = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) + self.attn = SelfAttention(context_size, mode=mode) + self.norm2 = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) + self.mlp = MLP( + in_features=context_size, + hidden_features=context_size * 4, + act_layer=lambda: nn.GELU(approximate="tanh"), + ) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, context_size, num_layers, mode="xformers"): + super().__init__() + self.layers = nn.ModuleList([TransformerBlock(context_size, mode) for _ in range(num_layers)]) + self.norm = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return self.norm(x) + + +# DismantledBlock in mmdit.py +class SingleDiTBlock(nn.Module): + """ + A DiT block with gated adaptive layer norm (adaLN) conditioning. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: str = "xformers", + qkv_bias: bool = False, + pre_only: bool = False, + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + qk_norm: Optional[str] = None, + **block_kwargs, + ): + super().__init__() + assert attn_mode in MEMORY_LAYOUTS + self.attn_mode = attn_mode + if not rmsnorm: + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + else: + self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = AttentionLinears( + dim=hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + pre_only=pre_only, + qk_norm=qk_norm, + ) + if not pre_only: + if not rmsnorm: + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + else: + self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + if not pre_only: + if not swiglu: + self.mlp = MLP( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=lambda: nn.GELU(approximate="tanh"), + ) + else: + self.mlp = SwiGLUFeedForward( + dim=hidden_size, + hidden_dim=mlp_hidden_dim, + multiple_of=256, + ) + self.scale_mod_only = scale_mod_only + if not scale_mod_only: + n_mods = 6 if not pre_only else 2 + else: + n_mods = 4 if not pre_only else 1 + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size)) + self.pre_only = pre_only + + def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + if not self.pre_only: + if not self.scale_mod_only: + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) = self.adaLN_modulation( + c + ).chunk(6, dim=-1) + else: + shift_msa = None + shift_mlp = None + ( + scale_msa, + gate_msa, + scale_mlp, + gate_mlp, + ) = self.adaLN_modulation( + c + ).chunk(4, dim=-1) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, ( + x, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) + else: + if not self.scale_mod_only: + ( + shift_msa, + scale_msa, + ) = self.adaLN_modulation( + c + ).chunk(2, dim=-1) + else: + shift_msa = None + scale_msa = self.adaLN_modulation(c) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, None + + def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp): + assert not self.pre_only + x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +# JointBlock + block_mixing in mmdit.py +class MMDiTBlock(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + pre_only = kwargs.pop("pre_only") + self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs) + self.x_block = SingleDiTBlock(*args, pre_only=False, **kwargs) + self.head_dim = self.x_block.attn.head_dim + self.mode = self.x_block.attn_mode + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def _forward(self, context, x, c): + ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c) + x_qkv, x_intermediate = self.x_block.pre_attention(x, c) + + ctx_len = ctx_qkv[0].size(1) + + q = torch.concat((ctx_qkv[0], x_qkv[0]), dim=1) + k = torch.concat((ctx_qkv[1], x_qkv[1]), dim=1) + v = torch.concat((ctx_qkv[2], x_qkv[2]), dim=1) + + attn = attention(q, k, v, head_dim=self.head_dim, mode=self.mode) + ctx_attn_out = attn[:, :ctx_len] + x_attn_out = attn[:, ctx_len:] + + x = self.x_block.post_attention(x_attn_out, *x_intermediate) + if not self.context_block.pre_only: + context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate) + else: + context = None + return context, x + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + +class MMDiT(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + + def __init__( + self, + input_size: int = 32, + patch_size: int = 2, + in_channels: int = 4, + depth: int = 28, + # hidden_size: Optional[int] = None, + # num_heads: Optional[int] = None, + mlp_ratio: float = 4.0, + learn_sigma: bool = False, + adm_in_channels: Optional[int] = None, + context_embedder_config: Optional[Dict] = None, + use_checkpoint: bool = False, + register_length: int = 0, + attn_mode: str = "torch", + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + out_channels: Optional[int] = None, + pos_embed_scaling_factor: Optional[float] = None, + pos_embed_offset: Optional[float] = None, + pos_embed_max_size: Optional[int] = None, + num_patches=None, + qk_norm: Optional[str] = None, + qkv_bias: bool = True, + context_processor_layers=None, + context_size=4096, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + default_out_channels = in_channels * 2 if learn_sigma else in_channels + self.out_channels = default(out_channels, default_out_channels) + self.patch_size = patch_size + self.pos_embed_scaling_factor = pos_embed_scaling_factor + self.pos_embed_offset = pos_embed_offset + self.pos_embed_max_size = pos_embed_max_size + self.gradient_checkpointing = use_checkpoint + + # hidden_size = default(hidden_size, 64 * depth) + # num_heads = default(num_heads, hidden_size // 64) + + # apply magic --> this defines a head_size of 64 + self.hidden_size = 64 * depth + num_heads = depth + + self.num_heads = num_heads + + self.x_embedder = PatchEmbed( + input_size, + patch_size, + in_channels, + self.hidden_size, + bias=True, + strict_img_size=self.pos_embed_max_size is None, + ) + self.t_embedder = TimestepEmbedding(self.hidden_size) + + self.y_embedder = None + if adm_in_channels is not None: + assert isinstance(adm_in_channels, int) + self.y_embedder = Embedder(adm_in_channels, self.hidden_size) + + if context_processor_layers is not None: + self.context_processor = Transformer(context_size, context_processor_layers, attn_mode) + else: + self.context_processor = None + + self.context_embedder = nn.Linear(context_size, self.hidden_size) + self.register_length = register_length + if self.register_length > 0: + self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size)) + + # num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + # just use a buffer already + if num_patches is not None: + self.register_buffer( + "pos_embed", + torch.empty(1, num_patches, self.hidden_size), + ) + else: + self.pos_embed = None + + self.use_checkpoint = use_checkpoint + self.joint_blocks = nn.ModuleList( + [ + MMDiTBlock( + self.hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + qkv_bias=qkv_bias, + pre_only=i == depth - 1, + rmsnorm=rmsnorm, + scale_mod_only=scale_mod_only, + swiglu=swiglu, + qk_norm=qk_norm, + ) + for i in range(depth) + ] + ) + for block in self.joint_blocks: + block.gradient_checkpointing = use_checkpoint + + self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels) + # self.initialize_weights() + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + for block in self.joint_blocks: + block.enable_gradient_checkpointing() + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + for block in self.joint_blocks: + block.disable_gradient_checkpointing() + + def initialize_weights(self): + # TODO: Init context_embedder? + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize (and freeze) pos_embed by sin-cos embedding + if self.pos_embed is not None: + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], + int(self.pos_embed.shape[-2] ** 0.5), + scaling_factor=self.pos_embed_scaling_factor, + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + if getattr(self, "y_embedder", None) is not None: + nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.joint_blocks: + nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def cropped_pos_embed(self, h, w, device=None): + p = self.x_embedder.patch_size + # patched size + h = (h + 1) // p + w = (w + 1) // p + if self.pos_embed is None: + return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device) + assert self.pos_embed_max_size is not None + assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size) + assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size) + top = (self.pos_embed_max_size - h) // 2 + left = (self.pos_embed_max_size - w) // 2 + spatial_pos_embed = self.pos_embed.reshape( + 1, + self.pos_embed_max_size, + self.pos_embed_max_size, + self.pos_embed.shape[-1], + ) + spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + return spatial_pos_embed + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass of DiT. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, D) tensor of class labels + """ + + if self.context_processor is not None: + context = self.context_processor(context) + + B, C, H, W = x.shape + x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device).to(dtype=x.dtype) + c = self.t_embedder(t, dtype=x.dtype) # (N, D) + if y is not None and self.y_embedder is not None: + y = self.y_embedder(y) # (N, D) + c = c + y # (N, D) + + if context is not None: + context = self.context_embedder(context) + + if self.register_length > 0: + context = torch.cat( + ( + einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), + default(context, torch.Tensor([]).type_as(x)), + ), + 1, + ) + + for block in self.joint_blocks: + context, x = block(context, x, c) + x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify + return x[:, :, :H, :W] + + +def create_mmdit_sd3_medium_configs(attn_mode: str): + # {'patch_size': 2, 'depth': 24, 'num_patches': 36864, + # 'pos_embed_max_size': 192, 'adm_in_channels': 2048, 'context_embedder': + # {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}}} + mmdit = MMDiT( + input_size=None, + pos_embed_max_size=192, + patch_size=2, + in_channels=16, + adm_in_channels=2048, + depth=24, + mlp_ratio=4, + qk_norm=None, + num_patches=36864, + context_size=4096, + attn_mode=attn_mode, + ) + return mmdit + + +# endregion + +# region VAE + + +def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) + + +class ResnetBlock(torch.nn.Module): + def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = Normalize(in_channels, dtype=dtype, device=device) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.norm2 = Normalize(out_channels, dtype=dtype, device=device) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + if self.in_channels != self.out_channels: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device + ) + else: + self.nin_shortcut = None + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + hidden = x + hidden = self.norm1(hidden) + hidden = self.swish(hidden) + hidden = self.conv1(hidden) + hidden = self.norm2(hidden) + hidden = self.swish(hidden) + hidden = self.conv2(hidden) + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + return x + hidden + + +class AttnBlock(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.norm = Normalize(in_channels, dtype=dtype, device=device) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + + def forward(self, x): + hidden = self.norm(x) + q = self.q(hidden) + k = self.k(hidden) + v = self.v(hidden) + b, c, h, w = q.shape + q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)) + hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default + hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + hidden = self.proj_out(hidden) + return x + hidden + + +class Downsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device) + + def forward(self, x): + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class VAEEncoder(torch.nn.Module): + def __init__( + self, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = torch.nn.ModuleList() + for i_level in range(self.num_resolutions): + block = torch.nn.ModuleList() + attn = torch.nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device)) + block_in = block_out + down = torch.nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, dtype=dtype, device=device) + self.down.append(down) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = self.swish(h) + h = self.conv_out(h) + return h + + +class VAEDecoder(torch.nn.Module): + def __init__( + self, + ch=128, + out_ch=3, + ch_mult=(1, 2, 4, 4), + num_res_blocks=2, + resolution=256, + z_channels=16, + dtype=torch.float32, + device=None, + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + # upsampling + self.up = torch.nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = torch.nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device)) + block_in = block_out + up = torch.nn.Module() + up.block = block + if i_level != 0: + up.upsample = Upsample(block_in, dtype=dtype, device=device) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, z): + # z to block_in + hidden = self.conv_in(z) + # middle + hidden = self.mid.block_1(hidden) + hidden = self.mid.attn_1(hidden) + hidden = self.mid.block_2(hidden) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + hidden = self.up[i_level].block[i_block](hidden) + if i_level != 0: + hidden = self.up[i_level].upsample(hidden) + # end + hidden = self.norm_out(hidden) + hidden = self.swish(hidden) + hidden = self.conv_out(hidden) + return hidden + + +class SDVAE(torch.nn.Module): + def __init__(self, dtype=torch.float32, device=None): + super().__init__() + self.encoder = VAEEncoder(dtype=dtype, device=device) + self.decoder = VAEDecoder(dtype=dtype, device=device) + + @torch.autocast("cuda", dtype=torch.float16) + def decode(self, latent): + return self.decoder(latent) + + @torch.autocast("cuda", dtype=torch.float16) + def encode(self, image): + hidden = self.encoder(image) + mean, logvar = torch.chunk(hidden, 2, dim=1) + logvar = torch.clamp(logvar, -30.0, 20.0) + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + + +# endregion + + +# region Text Encoder +class CLIPAttention(torch.nn.Module): + def __init__(self, embed_dim, heads, dtype, device, mode="xformers"): + super().__init__() + self.heads = heads + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.attn_mode = mode + + def set_attn_mode(self, mode): + self.attn_mode = mode + + def forward(self, x, mask=None): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + out = attention(q, k, v, self.heads, mask, mode=self.attn_mode) + return self.out_proj(out) + + +ACTIVATIONS = { + "quick_gelu": lambda: (lambda a: a * torch.sigmoid(1.702 * a)), + # "gelu": torch.nn.functional.gelu, + "gelu": lambda: nn.GELU(), +} + + +class CLIPLayer(torch.nn.Module): + def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): + super().__init__() + self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + self.self_attn = CLIPAttention(embed_dim, heads, dtype, device) + self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + # # self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device) + # self.mlp = Mlp( + # embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device + # ) + self.mlp = MLP(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation]) + self.mlp.to(device=device, dtype=dtype) + + def forward(self, x, mask=None): + x += self.self_attn(self.layer_norm1(x), mask) + x += self.mlp(self.layer_norm2(x)) + return x + + +class CLIPEncoder(torch.nn.Module): + def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): + super().__init__() + self.layers = torch.nn.ModuleList( + [CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)] + ) + + def forward(self, x, mask=None, intermediate_output=None): + if intermediate_output is not None: + if intermediate_output < 0: + intermediate_output = len(self.layers) + intermediate_output + intermediate = None + for i, l in enumerate(self.layers): + x = l(x, mask) + if i == intermediate_output: + intermediate = x.clone() + return x, intermediate + + +class CLIPEmbeddings(torch.nn.Module): + def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None): + super().__init__() + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device) + self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens): + return self.token_embedding(input_tokens) + self.position_embedding.weight + + +class CLIPTextModel_(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + num_layers = config_dict["num_hidden_layers"] + embed_dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + intermediate_size = config_dict["intermediate_size"] + intermediate_activation = config_dict["hidden_act"] + super().__init__() + self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) + self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) + self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True): + x = self.embeddings(input_tokens) + + if x.dtype == torch.bfloat16: + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=torch.float32, device=x.device).fill_(float("-inf")).triu_(1) + causal_mask = causal_mask.to(dtype=x.dtype) + else: + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + + x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output) + x = self.final_layer_norm(x) + if i is not None and final_layer_norm_intermediate: + i = self.final_layer_norm(i) + pooled_output = x[ + torch.arange(x.shape[0], device=x.device), + input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1), + ] + return x, i, pooled_output + + +class CLIPTextModel(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_hidden_layers"] + self.text_model = CLIPTextModel_(config_dict, dtype, device) + embed_dim = config_dict["hidden_size"] + self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) + self.text_projection.weight.copy_(torch.eye(embed_dim)) + self.dtype = dtype + + def get_input_embeddings(self): + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, embeddings): + self.text_model.embeddings.token_embedding = embeddings + + def forward(self, *args, **kwargs): + x = self.text_model(*args, **kwargs) + out = self.text_projection(x[2]) + return (x[0], x[1], out, x[2]) + + +class ClipTokenWeightEncoder: + def encode_token_weights(self, token_weight_pairs): + tokens = list(map(lambda a: a[0], token_weight_pairs[0])) + out, pooled = self([tokens]) + if pooled is not None: + first_pooled = pooled[0:1].cpu() + else: + first_pooled = pooled + output = [out[0:1]] + return torch.cat(output, dim=-2).cpu(), first_pooled + + +class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + device="cpu", + max_length=77, + layer="last", + layer_idx=None, + textmodel_json_config=None, + dtype=None, + model_class=CLIPTextModel, + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, + layer_norm_hidden_state=True, + return_projected_pooled=True, + ): + super().__init__() + assert layer in self.LAYERS + self.transformer = model_class(textmodel_json_config, dtype, device) + self.num_layers = self.transformer.num_layers + self.max_length = max_length + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + self.layer = layer + self.layer_idx = None + self.special_tokens = special_tokens + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) + self.layer_norm_hidden_state = layer_norm_hidden_state + self.return_projected_pooled = return_projected_pooled + if layer == "hidden": + assert layer_idx is not None + assert abs(layer_idx) < self.num_layers + self.set_clip_options({"layer": layer_idx}) + self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) + + def set_attn_mode(self, mode): + raise NotImplementedError("This model does not support setting the attention mode") + + def set_clip_options(self, options): + layer_idx = options.get("layer", self.layer_idx) + self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) + if layer_idx is None or abs(layer_idx) > self.num_layers: + self.layer = "last" + else: + self.layer = "hidden" + self.layer_idx = layer_idx + + def forward(self, tokens): + backup_embeds = self.transformer.get_input_embeddings() + device = backup_embeds.weight.device + tokens = torch.LongTensor(tokens).to(device) + outputs = self.transformer( + tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state + ) + self.transformer.set_input_embeddings(backup_embeds) + if self.layer == "last": + z = outputs[0] + else: + z = outputs[1] + pooled_output = None + if len(outputs) >= 3: + if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None: + pooled_output = outputs[3].float() + elif outputs[2] is not None: + pooled_output = outputs[2].float() + return z.float(), pooled_output + + def set_attn_mode(self, mode): + clip_text_model = self.transformer.text_model + for layer in clip_text_model.encoder.layers: + layer.self_attn.set_attn_mode(mode) + + +class SDXLClipG(SDClipModel): + """Wraps the CLIP-G model into the SD-CLIP-Model interface""" + + def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None): + if layer == "penultimate": + layer = "hidden" + layer_idx = -2 + super().__init__( + device=device, + layer=layer, + layer_idx=layer_idx, + textmodel_json_config=config, + dtype=dtype, + special_tokens={"start": 49406, "end": 49407, "pad": 0}, + layer_norm_hidden_state=False, + ) + + def set_attn_mode(self, mode): + clip_text_model = self.transformer.text_model + for layer in clip_text_model.encoder.layers: + layer.self_attn.set_attn_mode(mode) + + +class T5XXLModel(SDClipModel): + """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience""" + + def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None): + super().__init__( + device=device, + layer=layer, + layer_idx=layer_idx, + textmodel_json_config=config, + dtype=dtype, + special_tokens={"end": 1, "pad": 0}, + model_class=T5, + ) + + def set_attn_mode(self, mode): + t5: T5 = self.transformer + for t5block in t5.encoder.block: + t5block: T5Block + t5layer: T5LayerSelfAttention = t5block.layer[0] + t5SaSa: T5Attention = t5layer.SelfAttention + t5SaSa.set_attn_mode(mode) + + +################################################################################################# +### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl +################################################################################################# + + +class T5XXLTokenizer(SDTokenizer): + """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" + + def __init__(self): + super().__init__( + pad_with_end=False, + tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), + has_start_token=False, + pad_to_max_length=False, + max_length=99999999, + min_length=77, + ) + + +class T5LayerNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device)) + self.variance_epsilon = eps + + def forward(self, x): + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + return self.weight.to(device=x.device, dtype=x.dtype) * x + + +class T5DenseGatedActDense(torch.nn.Module): + def __init__(self, model_dim, ff_dim, dtype, device): + super().__init__() + self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) + self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) + self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device) + + def forward(self, x): + hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh") + hidden_linear = self.wi_1(x) + x = hidden_gelu * hidden_linear + x = self.wo(x) + return x + + +class T5LayerFF(torch.nn.Module): + def __init__(self, model_dim, ff_dim, dtype, device): + super().__init__() + self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device) + self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, x): + forwarded_states = self.layer_norm(x) + forwarded_states = self.DenseReluDense(forwarded_states) + x += forwarded_states + return x + + +class T5Attention(torch.nn.Module): + def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device) + self.num_heads = num_heads + self.relative_attention_bias = None + if relative_attention_bias: + self.relative_attention_num_buckets = 32 + self.relative_attention_max_distance = 128 + self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device) + + self.attn_mode = "xformers" # TODO 何とかする + + def set_attn_mode(self, mode): + self.attn_mode = mode + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=True, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward(self, x, past_bias=None): + q = self.q(x) + k = self.k(x) + v = self.v(x) + if self.relative_attention_bias is not None: + past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device) + if past_bias is not None: + mask = past_bias + out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask, mode=self.attn_mode) + return self.o(out), past_bias + + +class T5LayerSelfAttention(torch.nn.Module): + def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device) + self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, x, past_bias=None): + output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias) + x += output + return x, past_bias + + +class T5Block(torch.nn.Module): + def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): + super().__init__() + self.layer = torch.nn.ModuleList() + self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device)) + self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device)) + + def forward(self, x, past_bias=None): + x, past_bias = self.layer[0](x, past_bias) + x = self.layer[-1](x) + return x, past_bias + + +class T5Stack(torch.nn.Module): + def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device): + super().__init__() + self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device) + self.block = torch.nn.ModuleList( + [ + T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) + for i in range(num_layers) + ] + ) + self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True): + intermediate = None + x = self.embed_tokens(input_ids) + past_bias = None + for i, l in enumerate(self.block): + # print(i, x.mean(), x.std()) + x, past_bias = l(x, past_bias) + if i == intermediate_output: + intermediate = x.clone() + # print(x.mean(), x.std()) + x = self.final_layer_norm(x) + if intermediate is not None and final_layer_norm_intermediate: + intermediate = self.final_layer_norm(intermediate) + # print(x.mean(), x.std()) + return x, intermediate + + +class T5(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_layers"] + self.encoder = T5Stack( + self.num_layers, + config_dict["d_model"], + config_dict["d_model"], + config_dict["d_ff"], + config_dict["num_heads"], + config_dict["vocab_size"], + dtype, + device, + ) + self.dtype = dtype + + def get_input_embeddings(self): + return self.encoder.embed_tokens + + def set_input_embeddings(self, embeddings): + self.encoder.embed_tokens = embeddings + + def forward(self, *args, **kwargs): + return self.encoder(*args, **kwargs) + + +def create_clip_l(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None): + r""" + state_dict is not loaded, but updated with missing keys + """ + CLIPL_CONFIG = { + "hidden_act": "quick_gelu", + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12, + } + with torch.no_grad(): + clip_l = SDClipModel( + layer="hidden", + layer_idx=-2, + device=device, + dtype=dtype, + layer_norm_hidden_state=False, + return_projected_pooled=False, + textmodel_json_config=CLIPL_CONFIG, + ) + if state_dict is not None: + # update state_dict if provided to include logit_scale and text_projection.weight avoid errors + if "logit_scale" not in state_dict: + state_dict["logit_scale"] = clip_l.logit_scale + if "transformer.text_projection.weight" not in state_dict: + state_dict["transformer.text_projection.weight"] = clip_l.transformer.text_projection.weight + return clip_l + + +def create_clip_g(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None): + r""" + state_dict is not loaded, but updated with missing keys + """ + CLIPG_CONFIG = { + "hidden_act": "gelu", + "hidden_size": 1280, + "intermediate_size": 5120, + "num_attention_heads": 20, + "num_hidden_layers": 32, + } + with torch.no_grad(): + clip_g = SDXLClipG(CLIPG_CONFIG, device=device, dtype=dtype) + if state_dict is not None: + if "logit_scale" not in state_dict: + state_dict["logit_scale"] = clip_g.logit_scale + return clip_g + + +def create_t5xxl(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> T5XXLModel: + T5_CONFIG = {"d_ff": 10240, "d_model": 4096, "num_heads": 64, "num_layers": 24, "vocab_size": 32128} + with torch.no_grad(): + t5 = T5XXLModel(T5_CONFIG, dtype=dtype, device=device) + if state_dict is not None: + if "logit_scale" not in state_dict: + state_dict["logit_scale"] = t5.logit_scale + if "transformer.shared.weight" in state_dict: + state_dict.pop("transformer.shared.weight") + return t5 + + +# endregion diff --git a/library/sd3_utils.py b/library/sd3_utils.py new file mode 100644 index 00000000..6f8c361f --- /dev/null +++ b/library/sd3_utils.py @@ -0,0 +1,113 @@ +import math +from typing import Dict +import torch + +from library import sd3_models + + +def get_cond( + prompt: str, + tokenizer: sd3_models.SD3Tokenizer, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: sd3_models.T5XXLModel, +): + l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt) + l_out, l_pooled = clip_l.encode_token_weights(l_tokens) + g_out, g_pooled = clip_g.encode_token_weights(g_tokens) + lg_out = torch.cat([l_out, g_out], dim=-1) + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + + if t5_tokens is None: + t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device) + else: + t5_out, t5_pooled = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None + t5_out = t5_out.to(lg_out.dtype) + + return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1) + + +# used if other sd3 models is available +r""" +def get_sd3_configs(state_dict: Dict): + # Important configuration values can be quickly determined by checking shapes in the source file + # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change) + # prefix = "model.diffusion_model." + prefix = "" + + patch_size = state_dict[prefix + "x_embedder.proj.weight"].shape[2] + depth = state_dict[prefix + "x_embedder.proj.weight"].shape[0] // 64 + num_patches = state_dict[prefix + "pos_embed"].shape[1] + pos_embed_max_size = round(math.sqrt(num_patches)) + adm_in_channels = state_dict[prefix + "y_embedder.mlp.0.weight"].shape[1] + context_shape = state_dict[prefix + "context_embedder.weight"].shape + context_embedder_config = { + "target": "torch.nn.Linear", + "params": {"in_features": context_shape[1], "out_features": context_shape[0]}, + } + return { + "patch_size": patch_size, + "depth": depth, + "num_patches": num_patches, + "pos_embed_max_size": pos_embed_max_size, + "adm_in_channels": adm_in_channels, + "context_embedder": context_embedder_config, + } + + +def create_mmdit_from_sd3_checkpoint(state_dict: Dict, attn_mode: str = "xformers"): + "" + Doesn't load state dict. + "" + sd3_configs = get_sd3_configs(state_dict) + + mmdit = sd3_models.MMDiT( + input_size=None, + pos_embed_max_size=sd3_configs["pos_embed_max_size"], + patch_size=sd3_configs["patch_size"], + in_channels=16, + adm_in_channels=sd3_configs["adm_in_channels"], + depth=sd3_configs["depth"], + mlp_ratio=4, + qk_norm=None, + num_patches=sd3_configs["num_patches"], + context_size=4096, + attn_mode=attn_mode, + ) + return mmdit +""" + + +class ModelSamplingDiscreteFlow: + """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models""" + + def __init__(self, shift=1.0): + self.shift = shift + timesteps = 1000 + self.sigmas = self.sigma(torch.arange(1, timesteps + 1, 1)) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return sigma * 1000 + + def sigma(self, timestep: torch.Tensor): + timestep = timestep / 1000.0 + if self.shift == 1.0: + return timestep + return self.shift * timestep / (1 + (self.shift - 1) * timestep) + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input - model_output * sigma + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + # assert max_denoise is False, "max_denoise not implemented" + # max_denoise is always True, I'm not sure why it's there + return sigma * noise + (1.0 - sigma) * latent_image diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py new file mode 100644 index 00000000..e14f784d --- /dev/null +++ b/sd3_minimal_inference.py @@ -0,0 +1,347 @@ +# Minimum Inference Code for SD3 + +import argparse +import datetime +import math +import os +import random +from typing import Optional, Tuple +import numpy as np + +import torch +from safetensors.torch import safe_open, load_file +from tqdm import tqdm +from PIL import Image + +from library.device_utils import init_ipex, get_preferred_device + +init_ipex() + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import sd3_models, sd3_utils + + +def get_noise(seed, latent): + generator = torch.manual_seed(seed) + return torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu").to(latent.dtype) + + +def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): + start = sampling.timestep(sampling.sigma_max) + end = sampling.timestep(sampling.sigma_min) + timesteps = torch.linspace(start, end, steps) + sigs = [] + for x in range(len(timesteps)): + ts = timesteps[x] + sigs.append(sampling.sigma(ts)) + sigs += [0.0] + return torch.FloatTensor(sigs) + + +def max_denoise(model_sampling, sigmas): + max_sigma = float(model_sampling.sigma_max) + sigma = float(sigmas[0]) + return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma + + +def do_sample( + height: int, + width: int, + initial_latent: Optional[torch.Tensor], + seed: int, + cond: Tuple[torch.Tensor, torch.Tensor], + neg_cond: Tuple[torch.Tensor, torch.Tensor], + mmdit: sd3_models.MMDiT, + steps: int, + guidance_scale: float, + dtype: torch.dtype, + device: str, +): + if initial_latent is None: + latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 + else: + latent = initial_latent + + latent = latent.to(dtype).to(device) + + noise = get_noise(seed, latent).to(device) + + model_sampling = sd3_utils.ModelSamplingDiscreteFlow() + + sigmas = get_sigmas(model_sampling, steps).to(device) + # sigmas = sigmas[int(steps * (1 - denoise)) :] # do not support i2i + + # conditioning = fix_cond(conditioning) + # neg_cond = fix_cond(neg_cond) + # extra_args = {"cond": cond, "uncond": neg_cond, "cond_scale": guidance_scale} + + noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas)) + + c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype) + y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype) + + x = noise_scaled.to(device).to(dtype) + # print(x.shape) + + with torch.no_grad(): + for i in tqdm(range(len(sigmas) - 1)): + sigma_hat = sigmas[i] + + timestep = model_sampling.timestep(sigma_hat).float() + timestep = torch.FloatTensor([timestep, timestep]).to(device) + + x_c_nc = torch.cat([x, x], dim=0) + # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + model_output = model_output.float() + batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) + + pos_out, neg_out = batched.chunk(2) + denoised = neg_out + (pos_out - neg_out) * guidance_scale + # print(denoised.shape) + + # d = to_d(x, sigma_hat, denoised) + dims_to_append = x.ndim - sigma_hat.ndim + sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] + # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) + """Converts a denoiser output to a Karras ODE derivative.""" + d = (x - denoised) / sigma_hat_dims + + dt = sigmas[i + 1] - sigma_hat + + # Euler method + x = x + d * dt + x = x.to(dtype) + + latent = x + scale_factor = 1.5305 + shift_factor = 0.0609 + # def process_out(self, latent): + # return (latent / self.scale_factor) + self.shift_factor + latent = (latent / scale_factor) + shift_factor + return latent + + +if __name__ == "__main__": + target_height = 1024 + target_width = 1024 + + # steps = 50 # 28 # 50 + guidance_scale = 5 + # seed = 1 # None # 1 + + device = get_preferred_device() + + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--clip_g", type=str, required=False) + parser.add_argument("--clip_l", type=str, required=False) + parser.add_argument("--t5xxl", type=str, required=False) + parser.add_argument("--prompt", type=str, default="A photo of a cat") + # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders + parser.add_argument("--negative_prompt", type=str, default="") + parser.add_argument("--output_dir", type=str, default=".") + parser.add_argument("--do_not_use_t5xxl", action="store_true") + parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch") + parser.add_argument("--fp16", action="store_true") + parser.add_argument("--bf16", action="store_true") + parser.add_argument("--seed", type=int, default=1) + parser.add_argument("--steps", type=int, default=50) + # parser.add_argument( + # "--lora_weights", + # type=str, + # nargs="*", + # default=[], + # help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)", + # ) + # parser.add_argument("--interactive", action="store_true") + args = parser.parse_args() + + seed = args.seed + steps = args.steps + + sd3_dtype = torch.float32 + if args.fp16: + sd3_dtype = torch.float16 + elif args.bf16: + sd3_dtype = torch.bfloat16 + + # TODO test with separated safetenors files for each model + + # load state dict + logger.info(f"Loading SD3 models from {args.ckpt_path}...") + state_dict = load_file(args.ckpt_path) + + if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_g: remove prefix "text_encoders.clip_g." + logger.info("clip_g is included in the checkpoint") + clip_g_sd = {} + prefix = "text_encoders.clip_g." + for k, v in list(state_dict.items()): + if k.startswith(prefix): + clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) + else: + logger.info(f"Lodaing clip_g from {args.clip_g}...") + clip_g_sd = load_file(args.clip_g) + for key in list(clip_g_sd.keys()): + clip_g_sd["transformer." + key] = clip_g_sd.pop(key) + + if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_l: remove prefix "text_encoders.clip_l." + logger.info("clip_l is included in the checkpoint") + clip_l_sd = {} + prefix = "text_encoders.clip_l." + for k, v in list(state_dict.items()): + if k.startswith(prefix): + clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) + else: + logger.info(f"Lodaing clip_l from {args.clip_l}...") + clip_l_sd = load_file(args.clip_l) + for key in list(clip_l_sd.keys()): + clip_l_sd["transformer." + key] = clip_l_sd.pop(key) + + if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: + # found t5xxl: remove prefix "text_encoders.t5xxl." + logger.info("t5xxl is included in the checkpoint") + if not args.do_not_use_t5xxl: + t5xxl_sd = {} + prefix = "text_encoders.t5xxl." + for k, v in list(state_dict.items()): + if k.startswith(prefix): + t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) + else: + logger.info("but not used") + for key in list(state_dict.keys()): + if key.startswith("text_encoders.t5xxl."): + state_dict.pop(key) + t5xxl_sd = None + elif args.t5xxl: + assert not args.do_not_use_t5xxl, "t5xxl is not used but specified" + logger.info(f"Lodaing t5xxl from {args.t5xxl}...") + t5xxl_sd = load_file(args.t5xxl) + for key in list(t5xxl_sd.keys()): + t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) + else: + logger.info("t5xxl is not used") + t5xxl_sd = None + + use_t5xxl = t5xxl_sd is not None + + # MMDiT and VAE + vae_sd = {} + vae_prefix = "first_stage_model." + mmdit_prefix = "model.diffusion_model." + for k, v in list(state_dict.items()): + if k.startswith(vae_prefix): + vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) + elif k.startswith(mmdit_prefix): + state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k) + + # load tokenizers + logger.info("Loading tokenizers...") + tokenizer = sd3_models.SD3Tokenizer(use_t5xxl) # combined tokenizer + + # load models + # logger.info("Create MMDiT from SD3 checkpoint...") + # mmdit = sd3_utils.create_mmdit_from_sd3_checkpoint(state_dict) + logger.info("Create MMDiT") + mmdit = sd3_models.create_mmdit_sd3_medium_configs(args.attn_mode) + + logger.info("Loading state dict...") + info = mmdit.load_state_dict(state_dict) + logger.info(f"Loaded MMDiT: {info}") + + logger.info(f"Move MMDiT to {device} and {sd3_dtype}...") + mmdit.to(device, dtype=sd3_dtype) + mmdit.eval() + + # load VAE + logger.info("Create VAE") + vae = sd3_models.SDVAE() + logger.info("Loading state dict...") + info = vae.load_state_dict(vae_sd) + logger.info(f"Loaded VAE: {info}") + + logger.info(f"Move VAE to {device} and {sd3_dtype}...") + vae.to(device, dtype=sd3_dtype) + vae.eval() + + # load text encoders + logger.info("Create clip_l") + clip_l = sd3_models.create_clip_l(device, sd3_dtype, clip_l_sd) + + logger.info("Loading state dict...") + info = clip_l.load_state_dict(clip_l_sd) + logger.info(f"Loaded clip_l: {info}") + + logger.info(f"Move clip_l to {device} and {sd3_dtype}...") + clip_l.to(device, dtype=sd3_dtype) + clip_l.eval() + logger.info(f"Set attn_mode to {args.attn_mode}...") + clip_l.set_attn_mode(args.attn_mode) + + logger.info("Create clip_g") + clip_g = sd3_models.create_clip_g(device, sd3_dtype, clip_g_sd) + + logger.info("Loading state dict...") + info = clip_g.load_state_dict(clip_g_sd) + logger.info(f"Loaded clip_g: {info}") + + logger.info(f"Move clip_g to {device} and {sd3_dtype}...") + clip_g.to(device, dtype=sd3_dtype) + clip_g.eval() + logger.info(f"Set attn_mode to {args.attn_mode}...") + clip_g.set_attn_mode(args.attn_mode) + + if use_t5xxl: + logger.info("Create t5xxl") + t5xxl = sd3_models.create_t5xxl(device, sd3_dtype, t5xxl_sd) + + logger.info("Loading state dict...") + info = t5xxl.load_state_dict(t5xxl_sd) + logger.info(f"Loaded t5xxl: {info}") + + logger.info(f"Move t5xxl to {device} and {sd3_dtype}...") + t5xxl.to(device, dtype=sd3_dtype) + # t5xxl.to("cpu", dtype=torch.float32) # run on CPU + t5xxl.eval() + logger.info(f"Set attn_mode to {args.attn_mode}...") + t5xxl.set_attn_mode(args.attn_mode) + else: + t5xxl = None + + # prepare embeddings + logger.info("Encoding prompts...") + # embeds, pooled_embed + cond = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl) + neg_cond = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl) + + # generate image + logger.info("Generating image...") + latent_sampled = do_sample( + target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, guidance_scale, sd3_dtype, device + ) + + # latent to image + with torch.no_grad(): + image = vae.decode(latent_sampled) + image = image.float() + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) + decoded_np = decoded_np.astype(np.uint8) + out_image = Image.fromarray(decoded_np) + + # save image + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") + out_image.save(output_path) + + logger.info(f"Saved image to {output_path}") From 25f961bc779bc79aef440813e3e8e92244ac5739 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 23 Jun 2024 13:24:30 +0900 Subject: [PATCH 067/356] fix to work cache_latents/text_encoder_outputs --- README.md | 6 ++++++ tools/cache_latents.py | 5 ++++- tools/cache_text_encoder_outputs.py | 5 ++++- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a7047a36..fd81a781 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### Jun 23, 2024 / 2024-06-23: + +- Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.) + +- `cache_latents.py` および `cache_text_encoder_outputs.py` が動作しなくなっていたのを修正しました。(次回リリースに含まれます。) + ### Apr 7, 2024 / 2024-04-07: v0.8.7 - The default value of `huber_schedule` in Scheduled Huber Loss is changed from `exponential` to `snr`, which is expected to give better results. diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 347db27f..32101de3 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -16,12 +16,13 @@ from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments setup_logging() import logging logger = logging.getLogger(__name__) def cache_to_disk(args: argparse.Namespace) -> None: + setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) # check cache latents arg @@ -94,6 +95,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: # acceleratorを準備する logger.info("prepare accelerator") + args.deepspeed = False accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -170,6 +172,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 5f1d6d20..a75d9da7 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -16,12 +16,13 @@ from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments setup_logging() import logging logger = logging.getLogger(__name__) def cache_to_disk(args: argparse.Namespace) -> None: + setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) # check cache arg @@ -99,6 +100,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: # acceleratorを準備する logger.info("prepare accelerator") + args.deepspeed = False accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -171,6 +173,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) From d53ea22b2a8366e6bc9f14aaeec057cd817f60d3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 23 Jun 2024 23:38:20 +0900 Subject: [PATCH 068/356] sd3 training --- README.md | 25 + library/sai_model_spec.py | 20 +- library/sd3_models.py | 102 ++++- library/sd3_train_utils.py | 544 ++++++++++++++++++++++ library/sd3_utils.py | 211 ++++++++- library/train_util.py | 137 +++++- sd3_minimal_inference.py | 7 +- sd3_train.py | 907 +++++++++++++++++++++++++++++++++++++ 8 files changed, 1909 insertions(+), 44 deletions(-) create mode 100644 library/sd3_train_utils.py create mode 100644 sd3_train.py diff --git a/README.md b/README.md index 946df58f..34aa2bb2 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,30 @@ This repository contains training, generation and utility scripts for Stable Diffusion. +## SD3 training + +SD3 training is done with `sd3_train.py`. + +`optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are necessary currently. + +`clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them. + +t5xxl doesn't seem to work with `fp16`, so use`bf16` or `fp32`. + +There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype. + +```toml +learning_rate = 1e-5 # seems to be too high +optimizer_type = "adafactor" +optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] +cache_text_encoder_outputs = true +cache_text_encoder_outputs_to_disk = true +vae_batch_size = 1 +cache_latents = true +cache_latents_to_disk = true +``` + +--- + [__Change History__](#change-history) is moved to the bottom of the page. 更新履歴は[ページ末尾](#change-history)に移しました。 diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index a63bd82e..f7bf644d 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -6,8 +6,10 @@ import os from typing import List, Optional, Tuple, Union import safetensors from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) r""" @@ -55,11 +57,14 @@ ARCH_SD_V1 = "stable-diffusion-v1" ARCH_SD_V2_512 = "stable-diffusion-v2-512" ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v" ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" +ARCH_SD3_M = "stable-diffusion-3-medium" +ARCH_SD3_UNKNOWN = "stable-diffusion-3" ADAPTER_LORA = "lora" ADAPTER_TEXTUAL_INVERSION = "textual-inversion" IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" +IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI" IMPL_DIFFUSERS = "diffusers" PRED_TYPE_EPSILON = "epsilon" @@ -113,7 +118,11 @@ def build_metadata( merged_from: Optional[str] = None, timesteps: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, + sd3: str = None, ): + """ + sd3: only supports "m" + """ # if state_dict is None, hash is not calculated metadata = {} @@ -126,6 +135,11 @@ def build_metadata( if sdxl: arch = ARCH_SD_XL_V1_BASE + elif sd3 is not None: + if sd3 == "m": + arch = ARCH_SD3_M + else: + arch = ARCH_SD3_UNKNOWN elif v2: if v_parameterization: arch = ARCH_SD_V2_768_V @@ -142,7 +156,7 @@ def build_metadata( metadata["modelspec.architecture"] = arch if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: - is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion + is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: # Stable Diffusion ckpt, TI, SDXL LoRA @@ -236,7 +250,7 @@ def build_metadata( # assert all([v is not None for v in metadata.values()]), metadata if not all([v is not None for v in metadata.values()]): logger.error(f"Internal error: some metadata values are None: {metadata}") - + return metadata @@ -250,7 +264,7 @@ def get_title(metadata: dict) -> Optional[str]: def load_metadata_from_safetensors(model: str) -> dict: if not model.endswith(".safetensors"): return {} - + with safetensors.safe_open(model, framework="pt") as f: metadata = f.metadata() if metadata is None: diff --git a/library/sd3_models.py b/library/sd3_models.py index 294a69b0..a4fe400e 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -1,11 +1,13 @@ -# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref +# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref # the original code is licensed under the MIT License # and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution! +from ast import Tuple from functools import partial import math -from typing import Dict, Optional +from types import SimpleNamespace +from typing import Dict, List, Optional, Union import einops import numpy as np import torch @@ -106,6 +108,8 @@ class SD3Tokenizer: self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) self.clip_g = SDXLClipGTokenizer(clip_tokenizer) self.t5xxl = T5XXLTokenizer() if t5xxl else None + # t5xxl has 99999999 max length, clip has 77 + self.model_max_length = self.clip_l.max_length # 77 def tokenize_with_weights(self, text: str): return ( @@ -870,6 +874,10 @@ class MMDiT(nn.Module): self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels) # self.initialize_weights() + @property + def model_type(self): + return "m" # only support medium + def enable_gradient_checkpointing(self): self.gradient_checkpointing = True for block in self.joint_blocks: @@ -1013,6 +1021,10 @@ def create_mmdit_sd3_medium_configs(attn_mode: str): # endregion # region VAE +# TODO support xformers + +VAE_SCALE_FACTOR = 1.5305 +VAE_SHIFT_FACTOR = 0.0609 def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None): @@ -1222,6 +1234,14 @@ class SDVAE(torch.nn.Module): self.encoder = VAEEncoder(dtype=dtype, device=device) self.decoder = VAEDecoder(dtype=dtype, device=device) + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + @torch.autocast("cuda", dtype=torch.float16) def decode(self, latent): return self.decoder(latent) @@ -1234,6 +1254,43 @@ class SDVAE(torch.nn.Module): std = torch.exp(0.5 * logvar) return mean + std * torch.randn_like(mean) + @staticmethod + def process_in(latent): + return (latent - VAE_SHIFT_FACTOR) * VAE_SCALE_FACTOR + + @staticmethod + def process_out(latent): + return (latent / VAE_SCALE_FACTOR) + VAE_SHIFT_FACTOR + + +class VAEOutput: + def __init__(self, latent): + self.latent = latent + + @property + def latent_dist(self): + return self + + def sample(self): + return self.latent + + +class VAEWrapper: + def __init__(self, vae): + self.vae = vae + + @property + def device(self): + return self.vae.device + + @property + def dtype(self): + return self.vae.dtype + + # latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + def encode(self, image): + return VAEOutput(self.vae.encode(image)) + # endregion @@ -1370,15 +1427,39 @@ class CLIPTextModel(torch.nn.Module): class ClipTokenWeightEncoder: - def encode_token_weights(self, token_weight_pairs): - tokens = list(map(lambda a: a[0], token_weight_pairs[0])) - out, pooled = self([tokens]) - if pooled is not None: - first_pooled = pooled[0:1].cpu() + # def encode_token_weights(self, token_weight_pairs): + # tokens = list(map(lambda a: a[0], token_weight_pairs[0])) + # out, pooled = self([tokens]) + # if pooled is not None: + # first_pooled = pooled[0:1] + # else: + # first_pooled = pooled + # output = [out[0:1]] + # return torch.cat(output, dim=-2), first_pooled + + # fix to support batched inputs + # : Union[List[Tuple[torch.Tensor, torch.Tensor]], List[List[Tuple[torch.Tensor, torch.Tensor]]]] + def encode_token_weights(self, list_of_token_weight_pairs): + has_batch = isinstance(list_of_token_weight_pairs[0][0], list) + + if has_batch: + list_of_tokens = [] + for pairs in list_of_token_weight_pairs: + tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0] + list_of_tokens.append(tokens) else: - first_pooled = pooled - output = [out[0:1]] - return torch.cat(output, dim=-2).cpu(), first_pooled + list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] + + out, pooled = self(list_of_tokens) + if has_batch: + return out, pooled + else: + if pooled is not None: + first_pooled = pooled[0:1] + else: + first_pooled = pooled + output = [out[0:1]] + return torch.cat(output, dim=-2), first_pooled class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): @@ -1694,6 +1775,7 @@ class T5Stack(torch.nn.Module): x = self.embed_tokens(input_ids) past_bias = None for i, l in enumerate(self.block): + # uncomment to debug layerwise output: fp16 may cause issues # print(i, x.mean(), x.std()) x, past_bias = l(x, past_bias) if i == intermediate_output: diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py new file mode 100644 index 00000000..4e45871f --- /dev/null +++ b/library/sd3_train_utils.py @@ -0,0 +1,544 @@ +import argparse +import math +import os +from typing import Optional, Tuple + +import torch +from safetensors.torch import save_file + +from library import sd3_models, sd3_utils, train_util +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate import init_empty_weights +from tqdm import tqdm + +# from transformers import CLIPTokenizer +# from library import model_util +# , sdxl_model_util, train_util, sdxl_original_unet +# from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from .sdxl_train_util import match_mixed_precision + + +def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype) -> Tuple[ + sd3_models.MMDiT, + Optional[sd3_models.SDClipModel], + Optional[sd3_models.SDXLClipG], + Optional[sd3_models.T5XXLModel], + sd3_models.SDVAE, +]: + model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 + + for pi in range(accelerator.state.num_processes): + if pi == accelerator.state.local_process_index: + logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + + mmdit, clip_l, clip_g, t5xxl, vae = sd3_utils.load_models( + args.pretrained_model_name_or_path, + args.clip_l, + args.clip_g, + args.t5xxl, + args.vae, + attn_mode, + accelerator.device if args.lowram else "cpu", + weight_dtype, + args.disable_mmap_load_safetensors, + t5xxl_device, + t5xxl_dtype, + ) + + # work on low-ram device + if args.lowram: + if clip_l is not None: + clip_l.to(accelerator.device) + if clip_g is not None: + clip_g.to(accelerator.device) + if t5xxl is not None: + t5xxl.to(accelerator.device) + vae.to(accelerator.device) + mmdit.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + accelerator.wait_for_everyone() + + return mmdit, clip_l, clip_g, t5xxl, vae + + +def save_models( + ckpt_path: str, + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: Optional[sd3_models.T5XXLModel], + sai_metadata: Optional[dict], + save_dtype: Optional[torch.dtype] = None, +): + r""" + Save models to checkpoint file. Only supports unified checkpoint format. + """ + + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + update_sd("model.diffusion_model.", mmdit.state_dict()) + update_sd("first_stage_model.", vae.state_dict()) + + if clip_l is not None: + update_sd("text_encoders.clip_l.", clip_l.state_dict()) + if clip_g is not None: + update_sd("text_encoders.clip_g.", clip_g.state_dict()) + if t5xxl is not None: + update_sd("text_encoders.t5xxl.", t5xxl.state_dict()) + + save_file(state_dict, ckpt_path, metadata=sai_metadata) + + +def save_sd3_model_on_train_end( + args: argparse.Namespace, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: Optional[sd3_models.T5XXLModel], + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec( + None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type + ) + save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype) + + train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_sd3_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: Optional[sd3_models.T5XXLModel], + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec( + None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type + ) + save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + True, + True, + epoch, + num_train_epochs, + global_step, + sd_saver, + None, + ) + + +def add_sd3_training_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) + parser.add_argument( + "--disable_mmap_load_safetensors", + action="store_true", + help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", + ) + + parser.add_argument( + "--clip_l", + type=str, + required=False, + help="CLIP-L model path. if not specified, use ckpt's state_dict / CLIP-Lモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--clip_g", + type=str, + required=False, + help="CLIP-G model path. if not specified, use ckpt's state_dict / CLIP-Gモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--t5xxl", + type=str, + required=False, + help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--save_clip", action="store_true", help="save CLIP models to checkpoint / CLIPモデルをチェックポイントに保存する" + ) + parser.add_argument( + "--save_t5xxl", action="store_true", help="save T5-XXL model to checkpoint / T5-XXLモデルをチェックポイントに保存する" + ) + + parser.add_argument( + "--t5xxl_device", + type=str, + default=None, + help="T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用", + ) + parser.add_argument( + "--t5xxl_dtype", + type=str, + default=None, + help="T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用", + ) + + # copy from Diffusers + parser.add_argument( + "--weighting_scheme", + type=str, + default="logit_normal", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + + +def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): + assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" + if args.v_parameterization: + logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") + + if args.clip_skip is not None: + logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") + + # if args.multires_noise_iterations: + # logger.info( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" + # ) + # else: + # if args.noise_offset is None: + # args.noise_offset = DEFAULT_NOISE_OFFSET + # elif args.noise_offset != DEFAULT_NOISE_OFFSET: + # logger.info( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" + # ) + # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") + + assert ( + not hasattr(args, "weighted_captions") or not args.weighted_captions + ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" + + if supportTextEncoderCaching: + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + args.cache_text_encoder_outputs = True + logger.warning( + "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / " + + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました" + ) + + +def sample_images(*args, **kwargs): + return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) + + +# region Diffusers + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils.torch_utils import randn_tensor +from diffusers.utils import BaseOutput + + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + ): + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps) + + sigmas = timesteps / self.config.num_train_timesteps + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + + timesteps = sigmas * self.config.num_train_timesteps + self.timesteps = timesteps.to(device=device) + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + sigma = self.sigmas[self.step_index] + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + + noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + # NOTE: "original_sample" should not be an expected prediction_type but is left in for + # backwards compatibility + + # if self.config.prediction_type == "vector_field": + + denoised = sample - model_output * sigma + # 2. Convert to an ODE derivative + derivative = (sample - denoised) / sigma_hat + + dt = self.sigmas[self.step_index + 1] - sigma_hat + + prev_sample = sample + derivative * dt + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + + def __len__(self): + return self.config.num_train_timesteps + + +# endregion diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 6f8c361f..c2c91412 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -1,30 +1,226 @@ import math -from typing import Dict +from typing import Dict, Optional, Union import torch +import safetensors +from safetensors.torch import load_file +from accelerate import init_empty_weights + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) from library import sd3_models +# TODO move some of functions to model_util.py +from library import sdxl_model_util + +# region models + + +def load_models( + ckpt_path: str, + clip_l_path: str, + clip_g_path: str, + t5xxl_path: str, + vae_path: str, + attn_mode: str, + device: Union[str, torch.device], + weight_dtype: torch.dtype, + disable_mmap: bool = False, + t5xxl_device: Optional[str] = None, + t5xxl_dtype: Optional[str] = None, +): + def load_state_dict(path: str, dvc: Union[str, torch.device] = device): + if disable_mmap: + return safetensors.torch.load(open(path, "rb").read()) + else: + try: + return load_file(path, device=dvc) + except: + return load_file(path) # prevent device invalid Error + + t5xxl_device = t5xxl_device or device + + logger.info(f"Loading SD3 models from {ckpt_path}...") + state_dict = load_state_dict(ckpt_path) + + # load clip_l + clip_l_sd = None + if clip_l_path: + logger.info(f"Loading clip_l from {clip_l_path}...") + clip_l_sd = load_state_dict(clip_l_path) + for key in list(clip_l_sd.keys()): + clip_l_sd["transformer." + key] = clip_l_sd.pop(key) + else: + if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_l: remove prefix "text_encoders.clip_l." + logger.info("clip_l is included in the checkpoint") + clip_l_sd = {} + prefix = "text_encoders.clip_l." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) + + # load clip_g + clip_g_sd = None + if clip_g_path: + logger.info(f"Loading clip_g from {clip_g_path}...") + clip_g_sd = load_state_dict(clip_g_path) + for key in list(clip_g_sd.keys()): + clip_g_sd["transformer." + key] = clip_g_sd.pop(key) + else: + if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_g: remove prefix "text_encoders.clip_g." + logger.info("clip_g is included in the checkpoint") + clip_g_sd = {} + prefix = "text_encoders.clip_g." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) + + # load t5xxl + t5xxl_sd = None + if t5xxl_path: + logger.info(f"Loading t5xxl from {t5xxl_path}...") + t5xxl_sd = load_state_dict(t5xxl_path, t5xxl_device) + for key in list(t5xxl_sd.keys()): + t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) + else: + if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: + # found t5xxl: remove prefix "text_encoders.t5xxl." + logger.info("t5xxl is included in the checkpoint") + t5xxl_sd = {} + prefix = "text_encoders.t5xxl." + for k in list(state_dict.keys()): + if k.startswith(prefix): + t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) + + # MMDiT and VAE + vae_sd = {} + if vae_path: + logger.info(f"Loading VAE from {vae_path}...") + vae_sd = load_state_dict(vae_path) + else: + # remove prefix "first_stage_model." + vae_sd = {} + vae_prefix = "first_stage_model." + for k in list(state_dict.keys()): + if k.startswith(vae_prefix): + vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) + + mmdit_prefix = "model.diffusion_model." + for k in list(state_dict.keys()): + if k.startswith(mmdit_prefix): + state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k) + else: + state_dict.pop(k) # remove other keys + + # load MMDiT + logger.info("Building MMDit") + with init_empty_weights(): + mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) + + logger.info("Loading state dict...") + info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype) + logger.info(f"Loaded MMDiT: {info}") + + # load ClipG and ClipL + if clip_l_sd is None: + clip_l = None + else: + logger.info("Building ClipL") + clip_l = sd3_models.create_clip_l(device, weight_dtype, clip_l_sd) + logger.info("Loading state dict...") + info = clip_l.load_state_dict(clip_l_sd) + logger.info(f"Loaded ClipL: {info}") + clip_l.set_attn_mode(attn_mode) + + if clip_g_sd is None: + clip_g = None + else: + logger.info("Building ClipG") + clip_g = sd3_models.create_clip_g(device, weight_dtype, clip_g_sd) + logger.info("Loading state dict...") + info = clip_g.load_state_dict(clip_g_sd) + logger.info(f"Loaded ClipG: {info}") + clip_g.set_attn_mode(attn_mode) + + # load T5XXL + if t5xxl_sd is None: + t5xxl = None + else: + logger.info("Building T5XXL") + t5xxl = sd3_models.create_t5xxl(t5xxl_device, t5xxl_dtype, t5xxl_sd) + logger.info("Loading state dict...") + info = t5xxl.load_state_dict(t5xxl_sd) + logger.info(f"Loaded T5XXL: {info}") + t5xxl.set_attn_mode(attn_mode) + + # load VAE + logger.info("Building VAE") + vae = sd3_models.SDVAE() + logger.info("Loading state dict...") + info = vae.load_state_dict(vae_sd) + logger.info(f"Loaded VAE: {info}") + + return mmdit, clip_l, clip_g, t5xxl, vae + + +# endregion +# region utils + def get_cond( prompt: str, tokenizer: sd3_models.SD3Tokenizer, clip_l: sd3_models.SDClipModel, clip_g: sd3_models.SDXLClipG, - t5xxl: sd3_models.T5XXLModel, + t5xxl: Optional[sd3_models.T5XXLModel] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt) + return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype) + + +def get_cond_from_tokens( + l_tokens, + g_tokens, + t5_tokens, + clip_l: sd3_models.SDClipModel, + clip_g: sd3_models.SDXLClipG, + t5xxl: Optional[sd3_models.T5XXLModel] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +): l_out, l_pooled = clip_l.encode_token_weights(l_tokens) g_out, g_pooled = clip_g.encode_token_weights(g_tokens) lg_out = torch.cat([l_out, g_out], dim=-1) lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + if device is not None: + lg_out = lg_out.to(device=device) + l_pooled = l_pooled.to(device=device) + g_pooled = g_pooled.to(device=device) + if dtype is not None: + lg_out = lg_out.to(dtype=dtype) + l_pooled = l_pooled.to(dtype=dtype) + g_pooled = g_pooled.to(dtype=dtype) + # t5xxl may be in another device (eg. cpu) if t5_tokens is None: - t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device) + t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype) else: - t5_out, t5_pooled = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None - t5_out = t5_out.to(lg_out.dtype) + t5_out, _ = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None + if device is not None: + t5_out = t5_out.to(device=device) + if dtype is not None: + t5_out = t5_out.to(dtype=dtype) - return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1) + # return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1) + return lg_out, t5_out, torch.cat((l_pooled, g_pooled), dim=-1) # used if other sd3 models is available @@ -111,3 +307,6 @@ class ModelSamplingDiscreteFlow: # assert max_denoise is False, "max_denoise not implemented" # max_denoise is always True, I'm not sure why it's there return sigma * noise + (1.0 - sigma) * latent_image + + +# endregion diff --git a/library/train_util.py b/library/train_util.py index 4736ff4f..c67e8737 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -58,7 +58,7 @@ from diffusers import ( KDPM2AncestralDiscreteScheduler, AutoencoderKL, ) -from library import custom_train_functions +from library import custom_train_functions, sd3_utils from library.original_unet import UNet2DConditionModel from huggingface_hub import hf_hub_download import numpy as np @@ -135,6 +135,7 @@ IMAGE_TRANSFORMS = transforms.Compose( ) TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" +TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" class ImageInfo: @@ -985,7 +986,7 @@ class BaseDataset(torch.utils.data.Dataset): ] ) - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching latents.") @@ -1006,7 +1007,7 @@ class BaseDataset(torch.utils.data.Dataset): # check disk cache exists and size of latents if cache_to_disk: - info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz" + info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix if not is_main_process: # store to info only continue @@ -1040,14 +1041,43 @@ class BaseDataset(torch.utils.data.Dataset): for batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) - # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる - # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する - # SD1/2に対応するにはv2のフラグを持つ必要があるので後回し + # if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype + # this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset + # to support SD1/2, it needs a flag for v2, but it is postponed def cache_text_encoder_outputs( - self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True + self, tokenizers, text_encoders, device, output_dtype, cache_to_disk=False, is_main_process=True ): assert len(tokenizers) == 2, "only support SDXL" + return self.cache_text_encoder_outputs_common( + tokenizers, text_encoders, [device, device], output_dtype, [output_dtype], cache_to_disk, is_main_process + ) + # same as above, but for SD3 + def cache_text_encoder_outputs_sd3( + self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True + ): + return self.cache_text_encoder_outputs_common( + [tokenizer], + text_encoders, + devices, + output_dtype, + te_dtypes, + cache_to_disk, + is_main_process, + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3, + ) + + def cache_text_encoder_outputs_common( + self, + tokenizers, + text_encoders, + devices, + output_dtype, + te_dtypes, + cache_to_disk=False, + is_main_process=True, + file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX, + ): # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching text encoder outputs.") @@ -1058,13 +1088,14 @@ class BaseDataset(torch.utils.data.Dataset): for info in tqdm(image_infos): # subset = self.image_to_subset[info.image_key] if cache_to_disk: - te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX + te_out_npz = os.path.splitext(info.absolute_path)[0] + file_suffix info.text_encoder_outputs_npz = te_out_npz if not is_main_process: # store to info only continue if os.path.exists(te_out_npz): + # TODO check varidity of cache here continue image_infos_to_cache.append(info) @@ -1073,18 +1104,23 @@ class BaseDataset(torch.utils.data.Dataset): return # prepare tokenizers and text encoders - for text_encoder in text_encoders: + for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes): text_encoder.to(device) - if weight_dtype is not None: - text_encoder.to(dtype=weight_dtype) + if te_dtype is not None: + text_encoder.to(dtype=te_dtype) # create batch + is_sd3 = len(tokenizers) == 1 batch = [] batches = [] for info in image_infos_to_cache: - input_ids1 = self.get_input_ids(info.caption, tokenizers[0]) - input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) - batch.append((info, input_ids1, input_ids2)) + if not is_sd3: + input_ids1 = self.get_input_ids(info.caption, tokenizers[0]) + input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) + batch.append((info, input_ids1, input_ids2)) + else: + l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption) + batch.append((info, l_tokens, g_tokens, t5_tokens)) if len(batch) >= self.batch_size: batches.append(batch) @@ -1095,13 +1131,32 @@ class BaseDataset(torch.utils.data.Dataset): # iterate batches: call text encoder and cache outputs for memory or disk logger.info("caching text encoder outputs...") - for batch in tqdm(batches): - infos, input_ids1, input_ids2 = zip(*batch) - input_ids1 = torch.stack(input_ids1, dim=0) - input_ids2 = torch.stack(input_ids2, dim=0) - cache_batch_text_encoder_outputs( - infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype - ) + if not is_sd3: + for batch in tqdm(batches): + infos, input_ids1, input_ids2 = zip(*batch) + input_ids1 = torch.stack(input_ids1, dim=0) + input_ids2 = torch.stack(input_ids2, dim=0) + cache_batch_text_encoder_outputs( + infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, output_dtype + ) + else: + for batch in tqdm(batches): + infos, l_tokens, g_tokens, t5_tokens = zip(*batch) + + # stack tokens + # l_tokens = [tokens[0] for tokens in l_tokens] + # g_tokens = [tokens[0] for tokens in g_tokens] + # t5_tokens = [tokens[0] for tokens in t5_tokens] + + cache_batch_text_encoder_outputs_sd3( + infos, + tokenizers[0], + text_encoders, + self.max_token_length, + cache_to_disk, + (l_tokens, g_tokens, t5_tokens), + output_dtype, + ) def get_image_size(self, image_path): return imagesize.get(image_path) @@ -1332,6 +1387,7 @@ class BaseDataset(torch.utils.data.Dataset): captions.append(caption) if not self.token_padding_disabled: # this option might be omitted in future + # TODO get_input_ids must support SD3 if self.XTI_layers: token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) else: @@ -2140,10 +2196,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset): for dataset in self.datasets: dataset.enable_XTI(*args, **kwargs) - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True @@ -2152,6 +2208,15 @@ class DatasetGroup(torch.utils.data.ConcatDataset): logger.info(f"[Dataset {i}]") dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) + def cache_text_encoder_outputs_sd3( + self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True + ): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.cache_text_encoder_outputs_sd3( + tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process + ) + def set_caching_mode(self, caching_mode): for dataset in self.datasets: dataset.set_caching_mode(caching_mode) @@ -2585,6 +2650,30 @@ def cache_batch_text_encoder_outputs( info.text_encoder_pool2 = pool2 +def cache_batch_text_encoder_outputs_sd3( + image_infos, tokenizer, text_encoders, max_token_length, cache_to_disk, input_ids, output_dtype +): + # make input_ids for each text encoder + l_tokens, g_tokens, t5_tokens = input_ids + + clip_l, clip_g, t5xxl = text_encoders + with torch.no_grad(): + b_lg_out, b_t5_out, b_pool = sd3_utils.get_cond_from_tokens( + l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, "cpu", output_dtype + ) + b_lg_out = b_lg_out.detach() + b_t5_out = b_t5_out.detach() + b_pool = b_pool.detach() + + for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool): + if cache_to_disk: + save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool) + else: + info.text_encoder_outputs1 = lg_out + info.text_encoder_outputs2 = t5_out + info.text_encoder_pool2 = pool + + def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2): np.savez( npz_path, @@ -2907,6 +2996,7 @@ def get_sai_model_spec( lora: bool, textual_inversion: bool, is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA + sd3: str = None, ): timestamp = time.time() @@ -2940,6 +3030,7 @@ def get_sai_model_spec( tags=args.metadata_tags, timesteps=timesteps, clip_skip=args.clip_skip, # None or int + sd3=sd3, ) return metadata diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index e14f784d..96e9da4a 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -320,8 +320,11 @@ if __name__ == "__main__": # prepare embeddings logger.info("Encoding prompts...") # embeds, pooled_embed - cond = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl) - neg_cond = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl) + lg_out, t5_out, pooled = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl) + cond = torch.cat([lg_out, t5_out], dim=-2), pooled + + lg_out, t5_out, pooled = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl) + neg_cond = torch.cat([lg_out, t5_out], dim=-2), pooled # generate image logger.info("Generating image...") diff --git a/sd3_train.py b/sd3_train.py new file mode 100644 index 00000000..0721b2ae --- /dev/null +++ b/sd3_train.py @@ -0,0 +1,907 @@ +# training with captions + +import argparse +import copy +import math +import os +from multiprocessing import Value +from typing import List +import toml + +from tqdm import tqdm + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + + +init_ipex() + +from accelerate.utils import set_seed +from diffusers import DDPMScheduler +from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils + +# , sdxl_model_util + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions + +# from library.custom_train_functions import ( +# apply_snr_weight, +# prepare_scheduler_for_custom_training, +# scale_v_prediction_loss_like_noise_prediction, +# add_v_prediction_like_loss, +# apply_debiased_estimation, +# apply_masked_loss, +# ) + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + assert ( + not args.weighted_captions + ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + assert ( + not args.train_text_encoder or not args.cache_text_encoder_outputs + ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + + # if args.block_lr: + # block_lrs = [float(lr) for lr in args.block_lr.split(",")] + # assert ( + # len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR + # ), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください" + # else: + # block_lrs = None + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # load tokenizer + sd3_tokenizer = sd3_models.SD3Tokenizer() + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=[sd3_tokenizer]) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, [sd3_tokenizer]) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(8) # TODO これでいいか確認 + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = weight_dtype # torch.float32 if args.no_half_vae else weight_dtype # SD3 VAE works with fp16 + + t5xxl_dtype = weight_dtype + if args.t5xxl_dtype is not None: + if args.t5xxl_dtype == "fp16": + t5xxl_dtype = torch.float16 + elif args.t5xxl_dtype == "bf16": + t5xxl_dtype = torch.bfloat16 + elif args.t5xxl_dtype == "fp32" or args.t5xxl_dtype == "float": + t5xxl_dtype = torch.float32 + else: + raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") + t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device + + # モデルを読み込む + attn_mode = "xformers" if args.xformers else "torch" + + assert ( + attn_mode == "torch" + ), f"attn_mode {attn_mode} is not supported. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + + mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( + args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype + ) + assert clip_l is not None, "clip_l is required / clip_lは必須です" + assert clip_g is not None, "clip_g is required / clip_gは必須です" + # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + vae_wrapper = sd3_models.VAEWrapper(vae) # make SD/SDXL compatible + with torch.no_grad(): + train_dataset_group.cache_latents( + vae_wrapper, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process, file_suffix="_sd3.npz" + ) + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # 学習を準備する:モデルを適切な状態にする + if args.gradient_checkpointing: + mmdit.enable_gradient_checkpointing() + train_mmdit = args.learning_rate != 0 + train_clip_l = False + train_clip_g = False + train_t5xxl = False + + # if args.train_text_encoder: + # # TODO each option for two text encoders? + # accelerator.print("enable text encoder training") + # if args.gradient_checkpointing: + # text_encoder1.gradient_checkpointing_enable() + # text_encoder2.gradient_checkpointing_enable() + # lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train + # lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train + # train_clip_l = lr_te1 != 0 + # train_clip_g = lr_te2 != 0 + + # # caching one text encoder output is not supported + # if not train_clip_l: + # text_encoder1.to(weight_dtype) + # if not train_clip_g: + # text_encoder2.to(weight_dtype) + # text_encoder1.requires_grad_(train_clip_l) + # text_encoder2.requires_grad_(train_clip_g) + # text_encoder1.train(train_clip_l) + # text_encoder2.train(train_clip_g) + # else: + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + clip_l.requires_grad_(False) + clip_g.requires_grad_(False) + clip_l.eval() + clip_g.eval() + if t5xxl is not None: + t5xxl.to(t5xxl_dtype) + t5xxl.requires_grad_(False) + t5xxl.eval() + + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad + + with torch.no_grad(), accelerator.autocast(): + train_dataset_group.cache_text_encoder_outputs_sd3( + sd3_tokenizer, + (clip_l, clip_g, t5xxl), + (accelerator.device, accelerator.device, t5xxl_device), + None, + (None, None, None), + args.cache_text_encoder_outputs_to_disk, + accelerator.is_main_process, + ) + accelerator.wait_for_everyone() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + mmdit.requires_grad_(train_mmdit) + if not train_mmdit: + mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared + + training_models = [] + params_to_optimize = [] + # if train_unet: + training_models.append(mmdit) + # if block_lrs is None: + params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate}) + # else: + # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs)) + + # if train_clip_l: + # training_models.append(text_encoder1) + # params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) + # if train_clip_g: + # training_models.append(text_encoder2) + # params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"train mmdit: {train_mmdit}") # , text_encoder1: {train_clip_l}, text_encoder2: {train_clip_g}") + accelerator.print(f"number of models: {len(training_models)}") + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.fused_optimizer_groups: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # This balances memory usage and management complexity. + + # calculate total number of parameters + n_total_params = sum(len(params["params"]) for params in params_to_optimize) + params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) + + # split params into groups, keeping the learning rate the same for all params in a group + # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) + grouped_params = [] + param_group = [] + param_group_lr = -1 + for group in params_to_optimize: + lr = group["lr"] + for p in group["params"]: + # if the learning rate is different for different params, start a new group + if lr != param_group_lr: + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = lr + + param_group.append(p) + + # if the group has enough parameters, start a new group + if len(param_group) == params_per_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = -1 + + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + # dataloaderを準備する + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.fused_optimizer_groups: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + mmdit.to(weight_dtype) + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + if t5xxl is not None: + t5xxl.to(weight_dtype) # TODO check works with fp16 or not + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + mmdit.to(weight_dtype) + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + if t5xxl is not None: + t5xxl.to(weight_dtype) + + # TODO check if this is necessary. SD3 uses pool for clip_l and clip_g + # # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer + # if train_clip_l: + # text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) + # text_encoder1.text_model.final_layer_norm.requires_grad_(False) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model( + args, + mmdit=mmdit, + # mmdie=mmdit if train_mmdit else None, + # text_encoder1=text_encoder1 if train_clip_l else None, + # text_encoder2=text_encoder2 if train_clip_g else None, + ) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # acceleratorがなんかよろしくやってくれるらしい + if train_mmdit: + mmdit = accelerator.prepare(mmdit) + # if train_clip_l: + # text_encoder1 = accelerator.prepare(text_encoder1) + # if train_clip_g: + # text_encoder2 = accelerator.prepare(text_encoder2) + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + clip_l.to("cpu", dtype=torch.float32) + clip_g.to("cpu", dtype=torch.float32) + if t5xxl is not None: + t5xxl.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + # TODO support CPU for text encoders + clip_l.to(accelerator.device) + clip_g.to(accelerator.device) + if t5xxl is not None: + t5xxl.to(accelerator.device) + + # TODO cache sample prompt's embeddings to free text encoder's memory + if args.cache_text_encoder_outputs: + if not args.save_t5xxl: + t5xxl = None # free memory + clean_memory_on_device(accelerator.device) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + elif args.fused_optimizer_groups: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def optimizer_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(optimizer_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + # noise_scheduler = DDPMScheduler( + # beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + # ) + + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + # prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + # if args.zero_terminal_snr: + # custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + # # For --sample_at_first + # sd3_train_utils.sample_images( + # accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], mmdit + # ) + + # following function will be moved to sd3_train_utils + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None + ): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + loss_recorder = train_util.LossRecorder() + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.fused_optimizer_groups: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = vae.encode(batch["images"].to(vae_dtype)).to(weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + # latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + latents = sd3_models.SDVAE.process_in(latents) + + if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: + # not cached, get text encoder outputs + # XXX This does not work yet + input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl = batch["input_ids"] + with torch.set_grad_enabled(args.train_text_encoder): + # TODO support weighted captions + # TODO support length > 75 + input_ids_clip_l = input_ids_clip_l.to(accelerator.device) + input_ids_clip_g = input_ids_clip_g.to(accelerator.device) + input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) + + # get text encoder outputs: outputs are concatenated + context, pool = sd3_utils.get_cond_from_tokens( + input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl, clip_l, clip_g, t5xxl + ) + else: + # encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) + # encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) + # pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + # TODO this reuses SDXL keys, it should be fixed + lg_out = batch["text_encoder_outputs1_list"] + t5_out = batch["text_encoder_outputs2_list"] + pool = batch["text_encoder_pool2_list"] + context = torch.cat([lg_out, t5_out], dim=-2) + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) + + # Add noise according to flow matching. + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + # call model + with accelerator.autocast(): + model_pred = mmdit(noisy_model_input, timesteps, context=context, y=pool) + + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + # Preconditioning of the model outputs. + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = latents + + # Compute regular loss. TODO simplify this + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.fused_optimizer_groups): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + # sdxl_train_util.sample_images( + # accelerator, + # args, + # None, + # global_step, + # accelerator.device, + # vae, + # [tokenizer1, tokenizer2], + # [text_encoder1, text_encoder2], + # mmdit, + # ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + sd3_train_utils.save_sd3_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + clip_l if args.save_clip else None, + clip_g if args.save_clip else None, + t5xxl if args.save_t5xxl else None, + mmdit, + vae, + ) + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if args.logging_dir is not None: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_mmdit) + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + sd3_train_utils.save_sd3_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + clip_l if args.save_clip else None, + clip_g if args.save_clip else None, + t5xxl if args.save_t5xxl else None, + mmdit, + vae, + ) + + # sdxl_train_util.sample_images( + # accelerator, + # args, + # epoch + 1, + # global_step, + # accelerator.device, + # vae, + # [tokenizer1, tokenizer2], + # [text_encoder1, text_encoder2], + # mmdit, + # ) + + is_main_process = accelerator.is_main_process + # if is_main_process: + mmdit = accelerator.unwrap_model(mmdit) + clip_l = accelerator.unwrap_model(clip_l) + clip_g = accelerator.unwrap_model(clip_g) + if t5xxl is not None: + t5xxl = accelerator.unwrap_model(t5xxl) + + accelerator.end_training() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + sd3_train_utils.save_sd3_model_on_train_end( + args, + save_dtype, + epoch, + global_step, + clip_l if args.save_clip else None, + clip_g if args.save_clip else None, + t5xxl if args.save_t5xxl else None, + mmdit, + vae, + ) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + sd3_train_utils.add_sd3_training_arguments(parser) + + # TE training is disabled temporarily + + # parser.add_argument( + # "--learning_rate_te1", + # type=float, + # default=None, + # help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率", + # ) + # parser.add_argument( + # "--learning_rate_te2", + # type=float, + # default=None, + # help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率", + # ) + + # parser.add_argument( + # "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + # ) + # parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + # parser.add_argument( + # "--no_half_vae", + # action="store_true", + # help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + # ) + # parser.add_argument( + # "--block_lr", + # type=str, + # default=None, + # help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " + # + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", + # ) + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) From 0fe4eafac996fa5139a311aadc86aca28ddc6930 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 24 Jun 2024 23:12:48 +0900 Subject: [PATCH 069/356] fix to use zero for initial latent --- sd3_minimal_inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index 96e9da4a..7f5f28ce 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -64,7 +64,8 @@ def do_sample( device: str, ): if initial_latent is None: - latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 + # latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 + latent = torch.zeros(1, 16, height // 8, width // 8, device=device) else: latent = initial_latent From 4802e4aaec74429f733fae289e41c5618ebb0e92 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 24 Jun 2024 23:13:14 +0900 Subject: [PATCH 070/356] workaround for long caption ref #1382 --- library/sd3_models.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index a4fe400e..c19aec6a 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -56,7 +56,7 @@ class SDTokenizer: self.inv_vocab = {v: k for k, v in vocab.items()} self.max_word_length = 8 - def tokenize_with_weights(self, text: str): + def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None): """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" """ @@ -79,6 +79,14 @@ class SDTokenizer: batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) if self.min_length is not None and len(batch) < self.min_length: batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) + + # truncate to max_length + # print(f"batch: {batch}, truncate: {truncate}, len(batch): {len(batch)}, max_length: {self.max_length}") + if truncate_to_max_length and len(batch) > self.max_length: + batch = batch[: self.max_length] + if truncate_length is not None and len(batch) > truncate_length: + batch = batch[:truncate_length] + return [batch] @@ -112,10 +120,15 @@ class SD3Tokenizer: self.model_max_length = self.clip_l.max_length # 77 def tokenize_with_weights(self, text: str): + # temporary truncate to max_length even for t5xxl return ( self.clip_l.tokenize_with_weights(text), self.clip_g.tokenize_with_weights(text), - self.t5xxl.tokenize_with_weights(text) if self.t5xxl is not None else None, + ( + self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.model_max_length) + if self.t5xxl is not None + else None + ), ) From 0b3e4f7ab62b7c93e66972b7bd2774b8fe679792 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 25 Jun 2024 20:03:09 +0900 Subject: [PATCH 071/356] show file name if error in load_image ref #1385 --- library/train_util.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 4736ff4f..760be33e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2434,16 +2434,20 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: return train_dataset_group -def load_image(image_path, alpha=False): - image = Image.open(image_path) - if alpha: - if not image.mode == "RGBA": - image = image.convert("RGBA") - else: - if not image.mode == "RGB": - image = image.convert("RGB") - img = np.array(image, np.uint8) - return img +def load_image(image_path, alpha=False): + try: + with Image.open(image_path) as image: + if alpha: + if not image.mode == "RGBA": + image = image.convert("RGBA") + else: + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) + return img + except (IOError, OSError) as e: + logger.error(f"Error loading file: {image_path}") + raise e # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom) From 8f2ba27869e4c5b9225a309aeed275a47d8eed6a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Jun 2024 20:36:22 +0900 Subject: [PATCH 072/356] support text_encoder_batch_size for caching --- library/sd3_train_utils.py | 7 +++++++ library/train_util.py | 14 ++++++++++---- sd3_train.py | 1 + 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 4e45871f..70c83c0b 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -173,6 +173,13 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): action="store_true", help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", ) + parser.add_argument( + "--text_encoder_batch_size", + type=int, + default=None, + help="text encoder batch size (default: None, use dataset's batch size)" + + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", + ) parser.add_argument( "--disable_mmap_load_safetensors", action="store_true", diff --git a/library/train_util.py b/library/train_util.py index c67e8737..96d32e3b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1054,7 +1054,7 @@ class BaseDataset(torch.utils.data.Dataset): # same as above, but for SD3 def cache_text_encoder_outputs_sd3( - self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True + self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None ): return self.cache_text_encoder_outputs_common( [tokenizer], @@ -1065,6 +1065,7 @@ class BaseDataset(torch.utils.data.Dataset): cache_to_disk, is_main_process, TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3, + batch_size, ) def cache_text_encoder_outputs_common( @@ -1077,10 +1078,15 @@ class BaseDataset(torch.utils.data.Dataset): cache_to_disk=False, is_main_process=True, file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX, + batch_size=None, ): # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching text encoder outputs.") + + if batch_size is None: + batch_size = self.batch_size + image_infos = list(self.image_data.values()) logger.info("checking cache existence...") @@ -1122,7 +1128,7 @@ class BaseDataset(torch.utils.data.Dataset): l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption) batch.append((info, l_tokens, g_tokens, t5_tokens)) - if len(batch) >= self.batch_size: + if len(batch) >= batch_size: batches.append(batch) batch = [] @@ -2209,12 +2215,12 @@ class DatasetGroup(torch.utils.data.ConcatDataset): dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) def cache_text_encoder_outputs_sd3( - self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True + self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None ): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") dataset.cache_text_encoder_outputs_sd3( - tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process + tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size ) def set_caching_mode(self, caching_mode): diff --git a/sd3_train.py b/sd3_train.py index 0721b2ae..8216a62b 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -254,6 +254,7 @@ def train(args): (None, None, None), args.cache_text_encoder_outputs_to_disk, accelerator.is_main_process, + args.text_encoder_batch_size, ) accelerator.wait_for_everyone() From 828a581e2968935c00d22e7e03ca32c1281aa5dd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Jun 2024 20:43:31 +0900 Subject: [PATCH 073/356] fix assertion for experimental impl ref #1389 --- sd3_train.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index 8216a62b..ea9a1104 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -60,9 +60,19 @@ def train(args): assert ( not args.weighted_captions ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + # assert ( + # not args.train_text_encoder or not args.cache_text_encoder_outputs + # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + + # training text encoder is not supported assert ( - not args.train_text_encoder or not args.cache_text_encoder_outputs - ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + not args.train_text_encoder + ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" + + # training without text encoder cache is not supported + assert ( + args.cache_text_encoder_outputs + ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません" # if args.block_lr: # block_lrs = [float(lr) for lr in args.block_lr.split(",")] From 381598c8bbd3d4e50ec4327fa27d5d0072ec2a67 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 26 Jun 2024 21:15:02 +0900 Subject: [PATCH 074/356] fix resolution in metadata for sd3 --- library/sai_model_spec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index f7bf644d..af073677 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -216,7 +216,7 @@ def build_metadata( reso = (reso[0], reso[0]) else: # resolution is defined in dataset, so use default - if sdxl: + if sdxl or sd3 is not None: reso = 1024 elif v2 and v_parameterization: reso = 768 From 66cf43547972647389fbd2addb53cff2ab478660 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 27 Jun 2024 13:14:09 +0900 Subject: [PATCH 075/356] re-fix assertion ref #1389 --- sd3_train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index ea9a1104..b6c932c4 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -64,10 +64,10 @@ def train(args): # not args.train_text_encoder or not args.cache_text_encoder_outputs # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" - # training text encoder is not supported - assert ( - not args.train_text_encoder - ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" + # # training text encoder is not supported + # assert ( + # not args.train_text_encoder + # ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" # training without text encoder cache is not supported assert ( From 19086465e8040c01c38d38eec5c53f966f0dad8b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 29 Jun 2024 17:21:25 +0900 Subject: [PATCH 076/356] Fix fp16 mixed precision, model is in bf16 without full_bf16 --- README.md | 11 +++++++-- library/sd3_train_utils.py | 10 +++++---- library/sd3_utils.py | 46 +++++++++++++++++++++++++++++++++----- sd3_train.py | 9 +++++--- 4 files changed, 61 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 34aa2bb2..3eed636c 100644 --- a/README.md +++ b/README.md @@ -4,21 +4,28 @@ This repository contains training, generation and utility scripts for Stable Dif SD3 training is done with `sd3_train.py`. +__Jun 29, 2024__: Fixed mixed precision training with fp16 is not working. Fixed the model is in bf16 dtype even without `--full_bf16` option (this could worsen the training result). + +`fp16` and `bf16` are available for mixed precision training. We are not sure which is better. + `optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are necessary currently. `clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them. -t5xxl doesn't seem to work with `fp16`, so use`bf16` or `fp32`. +t5xxl doesn't seem to work with `fp16`, so 1) use`bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`. There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype. +`text_encoder_batch_size` is added experimentally for caching faster. + ```toml -learning_rate = 1e-5 # seems to be too high +learning_rate = 1e-6 # seems to depend on the batch size optimizer_type = "adafactor" optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] cache_text_encoder_outputs = true cache_text_encoder_outputs_to_disk = true vae_batch_size = 1 +text_encoder_batch_size = 4 cache_latents = true cache_latents_to_disk = true ``` diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 70c83c0b..c8d52e1c 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -28,14 +28,14 @@ logger = logging.getLogger(__name__) from .sdxl_train_util import match_mixed_precision -def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype) -> Tuple[ +def load_target_model(args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype) -> Tuple[ sd3_models.MMDiT, Optional[sd3_models.SDClipModel], Optional[sd3_models.SDXLClipG], Optional[sd3_models.T5XXLModel], sd3_models.SDVAE, ]: - model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 + model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16, None or fp16/bf16 for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: @@ -49,13 +49,15 @@ def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device, args.vae, attn_mode, accelerator.device if args.lowram else "cpu", - weight_dtype, + model_dtype, args.disable_mmap_load_safetensors, + clip_dtype, t5xxl_device, t5xxl_dtype, + vae_dtype, ) - # work on low-ram device + # work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device if args.lowram: if clip_l is not None: clip_l.to(accelerator.device) diff --git a/library/sd3_utils.py b/library/sd3_utils.py index c2c91412..45b49b04 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -28,11 +28,41 @@ def load_models( vae_path: str, attn_mode: str, device: Union[str, torch.device], - weight_dtype: torch.dtype, + default_dtype: Optional[Union[str, torch.dtype]] = None, disable_mmap: bool = False, - t5xxl_device: Optional[str] = None, - t5xxl_dtype: Optional[str] = None, + clip_dtype: Optional[Union[str, torch.dtype]] = None, + t5xxl_device: Optional[Union[str, torch.device]] = None, + t5xxl_dtype: Optional[Union[str, torch.dtype]] = None, + vae_dtype: Optional[Union[str, torch.dtype]] = None, ): + """ + Load SD3 models from checkpoint files. + + Args: + ckpt_path: Path to the SD3 checkpoint file. + clip_l_path: Path to the clip_l checkpoint file. + clip_g_path: Path to the clip_g checkpoint file. + t5xxl_path: Path to the t5xxl checkpoint file. + vae_path: Path to the VAE checkpoint file. + attn_mode: Attention mode for MMDiT model. + device: Device for MMDiT model. + default_dtype: Default dtype for each model. In training, it's usually None. None means using float32. + disable_mmap: Disable memory mapping when loading state dict. + clip_dtype: Dtype for Clip models, or None to use default dtype. + t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device. + t5xxl_dtype: Dtype for T5XXL model, or None to use default dtype. + vae_dtype: Dtype for VAE model, or None to use default dtype. + + Returns: + Tuple of MMDiT, ClipL, ClipG, T5XXL, and VAE models. + """ + + # In SD1/2 and SDXL, the model is created with empty weights and then loaded with state dict. + # However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict. + # Therefore, we need clip_dtype and t5xxl_dtype. + + # default_dtype is used for full_fp16/full_bf16 training. + def load_state_dict(path: str, dvc: Union[str, torch.device] = device): if disable_mmap: return safetensors.torch.load(open(path, "rb").read()) @@ -43,6 +73,9 @@ def load_models( return load_file(path) # prevent device invalid Error t5xxl_device = t5xxl_device or device + clip_dtype = clip_dtype or default_dtype or torch.float32 + t5xxl_dtype = t5xxl_dtype or default_dtype or torch.float32 + vae_dtype = vae_dtype or default_dtype or torch.float32 logger.info(f"Loading SD3 models from {ckpt_path}...") state_dict = load_state_dict(ckpt_path) @@ -124,7 +157,7 @@ def load_models( mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) logger.info("Loading state dict...") - info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype) + info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, default_dtype) logger.info(f"Loaded MMDiT: {info}") # load ClipG and ClipL @@ -132,7 +165,7 @@ def load_models( clip_l = None else: logger.info("Building ClipL") - clip_l = sd3_models.create_clip_l(device, weight_dtype, clip_l_sd) + clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd) logger.info("Loading state dict...") info = clip_l.load_state_dict(clip_l_sd) logger.info(f"Loaded ClipL: {info}") @@ -142,7 +175,7 @@ def load_models( clip_g = None else: logger.info("Building ClipG") - clip_g = sd3_models.create_clip_g(device, weight_dtype, clip_g_sd) + clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd) logger.info("Loading state dict...") info = clip_g.load_state_dict(clip_g_sd) logger.info(f"Loaded ClipG: {info}") @@ -165,6 +198,7 @@ def load_models( logger.info("Loading state dict...") info = vae.load_state_dict(vae_sd) logger.info(f"Loaded VAE: {info}") + vae.to(device=device, dtype=vae_dtype) return mmdit, clip_l, clip_g, t5xxl, vae diff --git a/sd3_train.py b/sd3_train.py index b6c932c4..bd30cdc7 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -182,6 +182,8 @@ def train(args): raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device + clip_dtype = weight_dtype # if not args.train_text_encoder else None + # モデルを読み込む attn_mode = "xformers" if args.xformers else "torch" @@ -189,8 +191,9 @@ def train(args): attn_mode == "torch" ), f"attn_mode {attn_mode} is not supported. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0. mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( - args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype + args, accelerator, attn_mode, None, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype ) assert clip_l is not None, "clip_l is required / clip_lは必須です" assert clip_g is not None, "clip_g is required / clip_gは必須です" @@ -868,8 +871,9 @@ def setup_parser() -> argparse.ArgumentParser: custom_train_functions.add_custom_train_arguments(parser) sd3_train_utils.add_sd3_training_arguments(parser) - # TE training is disabled temporarily + # parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + # TE training is disabled temporarily # parser.add_argument( # "--learning_rate_te1", # type=float, @@ -886,7 +890,6 @@ def setup_parser() -> argparse.ArgumentParser: # parser.add_argument( # "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" # ) - # parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") # parser.add_argument( # "--no_half_vae", # action="store_true", From ea18d5ba6d856995d5c44be4b449b63ac66fe5db Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 29 Jun 2024 17:45:50 +0900 Subject: [PATCH 077/356] Fix to work full_bf16 and full_fp16. --- library/sd3_models.py | 8 ++++++++ library/sd3_utils.py | 14 ++++++-------- sd3_train.py | 20 ++++++++++---------- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index c19aec6a..7041420c 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -891,6 +891,14 @@ class MMDiT(nn.Module): def model_type(self): return "m" # only support medium + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + def enable_gradient_checkpointing(self): self.gradient_checkpointing = True for block in self.joint_blocks: diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 45b49b04..9dc9e796 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -28,7 +28,7 @@ def load_models( vae_path: str, attn_mode: str, device: Union[str, torch.device], - default_dtype: Optional[Union[str, torch.dtype]] = None, + weight_dtype: Optional[Union[str, torch.dtype]] = None, disable_mmap: bool = False, clip_dtype: Optional[Union[str, torch.dtype]] = None, t5xxl_device: Optional[Union[str, torch.device]] = None, @@ -46,7 +46,7 @@ def load_models( vae_path: Path to the VAE checkpoint file. attn_mode: Attention mode for MMDiT model. device: Device for MMDiT model. - default_dtype: Default dtype for each model. In training, it's usually None. None means using float32. + weight_dtype: Default dtype of weights for all models. This is weight dtype, so the model dtype may be different. disable_mmap: Disable memory mapping when loading state dict. clip_dtype: Dtype for Clip models, or None to use default dtype. t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device. @@ -61,8 +61,6 @@ def load_models( # However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict. # Therefore, we need clip_dtype and t5xxl_dtype. - # default_dtype is used for full_fp16/full_bf16 training. - def load_state_dict(path: str, dvc: Union[str, torch.device] = device): if disable_mmap: return safetensors.torch.load(open(path, "rb").read()) @@ -73,9 +71,9 @@ def load_models( return load_file(path) # prevent device invalid Error t5xxl_device = t5xxl_device or device - clip_dtype = clip_dtype or default_dtype or torch.float32 - t5xxl_dtype = t5xxl_dtype or default_dtype or torch.float32 - vae_dtype = vae_dtype or default_dtype or torch.float32 + clip_dtype = clip_dtype or weight_dtype or torch.float32 + t5xxl_dtype = t5xxl_dtype or weight_dtype or torch.float32 + vae_dtype = vae_dtype or weight_dtype or torch.float32 logger.info(f"Loading SD3 models from {ckpt_path}...") state_dict = load_state_dict(ckpt_path) @@ -157,7 +155,7 @@ def load_models( mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) logger.info("Loading state dict...") - info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, default_dtype) + info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype) logger.info(f"Loaded MMDiT: {info}") # load ClipG and ClipL diff --git a/sd3_train.py b/sd3_train.py index bd30cdc7..de763ac6 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -182,7 +182,7 @@ def train(args): raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device - clip_dtype = weight_dtype # if not args.train_text_encoder else None + clip_dtype = weight_dtype # if not args.train_text_encoder else None # モデルを読み込む attn_mode = "xformers" if args.xformers else "torch" @@ -193,7 +193,7 @@ def train(args): # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0. mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( - args, accelerator, attn_mode, None, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype + args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype ) assert clip_l is not None, "clip_l is required / clip_lは必須です" assert clip_g is not None, "clip_g is required / clip_gは必須です" @@ -769,10 +769,10 @@ def train(args): epoch, num_train_epochs, global_step, - clip_l if args.save_clip else None, - clip_g if args.save_clip else None, - t5xxl if args.save_t5xxl else None, - mmdit, + accelerator.unwrap_model(clip_l) if args.save_clip else None, + accelerator.unwrap_model(clip_g) if args.save_clip else None, + accelerator.unwrap_model(t5xxl) if args.save_t5xxl else None, + accelerator.unwrap_model(mmdit), vae, ) @@ -807,10 +807,10 @@ def train(args): epoch, num_train_epochs, global_step, - clip_l if args.save_clip else None, - clip_g if args.save_clip else None, - t5xxl if args.save_t5xxl else None, - mmdit, + accelerator.unwrap_model(clip_l) if args.save_clip else None, + accelerator.unwrap_model(clip_g) if args.save_clip else None, + accelerator.unwrap_model(t5xxl) if args.save_t5xxl else None, + accelerator.unwrap_model(mmdit), vae, ) From 50e3d6247459c9f59facaef42e03b34cd8d6287d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 8 Jul 2024 19:46:23 +0900 Subject: [PATCH 078/356] fix to work T5XXL with fp16 --- library/sd3_models.py | 144 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 138 insertions(+), 6 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 7041420c..e4c0790d 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -1124,7 +1124,12 @@ class Upsample(torch.nn.Module): self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) def forward(self, x): + org_dtype = x.dtype + if x.dtype == torch.bfloat16: + x = x.to(torch.float32) x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if x.dtype != org_dtype: + x = x.to(org_dtype) x = self.conv(x) return x @@ -1263,11 +1268,11 @@ class SDVAE(torch.nn.Module): def dtype(self): return next(self.parameters()).dtype - @torch.autocast("cuda", dtype=torch.float16) + # @torch.autocast("cuda", dtype=torch.float16) def decode(self, latent): return self.decoder(latent) - @torch.autocast("cuda", dtype=torch.float16) + # @torch.autocast("cuda", dtype=torch.float16) def encode(self, image): hidden = self.encoder(image) mean, logvar = torch.chunk(hidden, 2, dim=1) @@ -1630,10 +1635,25 @@ class T5LayerNorm(torch.nn.Module): self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device)) self.variance_epsilon = eps - def forward(self, x): - variance = x.pow(2).mean(-1, keepdim=True) - x = x * torch.rsqrt(variance + self.variance_epsilon) - return self.weight.to(device=x.device, dtype=x.dtype) * x + # def forward(self, x): + # variance = x.pow(2).mean(-1, keepdim=True) + # x = x * torch.rsqrt(variance + self.variance_epsilon) + # return self.weight.to(device=x.device, dtype=x.dtype) * x + + # copy from transformers' T5LayerNorm + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states class T5DenseGatedActDense(torch.nn.Module): @@ -1775,7 +1795,27 @@ class T5Block(torch.nn.Module): def forward(self, x, past_bias=None): x, past_bias = self.layer[0](x, past_bias) + + # copy from transformers' T5Block + # clamp inf values to enable fp16 training + if x.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(x).any(), + torch.finfo(x.dtype).max - 1000, + torch.finfo(x.dtype).max, + ) + x = torch.clamp(x, min=-clamp_value, max=clamp_value) + x = self.layer[-1](x) + # clamp inf values to enable fp16 training + if x.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(x).any(), + torch.finfo(x.dtype).max - 1000, + torch.finfo(x.dtype).max, + ) + x = torch.clamp(x, min=-clamp_value, max=clamp_value) + return x, past_bias @@ -1896,4 +1936,96 @@ def create_t5xxl(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[st return t5 +""" + # snippet for using the T5 model from transformers + + from transformers import T5EncoderModel, T5Config + import accelerate + import json + + T5_CONFIG_JSON = "" +{ + "architectures": [ + "T5EncoderModel" + ], + "classifier_dropout": 0.0, + "d_ff": 10240, + "d_kv": 64, + "d_model": 4096, + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 24, + "num_heads": 64, + "num_layers": 24, + "output_past": true, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.41.2", + "use_cache": true, + "vocab_size": 32128 +} +"" + config = json.loads(T5_CONFIG_JSON) + config = T5Config(**config) + + # model = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3") + # print(model.config) + # # model(**load_model.config) + + # with accelerate.init_empty_weights(): + model = T5EncoderModel._from_config(config) # , torch_dtype=dtype) + for key in list(state_dict.keys()): + if key.startswith("transformer."): + new_key = key[len("transformer.") :] + state_dict[new_key] = state_dict.pop(key) + + info = model.load_state_dict(state_dict) + print(info) + model.set_attn_mode = lambda x: None + # model.to("cpu") + + _self = model + + def enc(list_of_token_weight_pairs): + has_batch = isinstance(list_of_token_weight_pairs[0][0], list) + + if has_batch: + list_of_tokens = [] + for pairs in list_of_token_weight_pairs: + tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0] + list_of_tokens.append(tokens) + else: + list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] + + list_of_tokens = np.array(list_of_tokens) + list_of_tokens = torch.from_numpy(list_of_tokens).to("cuda", dtype=torch.long) + out = _self(list_of_tokens) + pooled = None + if has_batch: + return out, pooled + else: + if pooled is not None: + first_pooled = pooled[0:1] + else: + first_pooled = pooled + return out[0], first_pooled + # output = [out[0:1]] + # return torch.cat(output, dim=-2), first_pooled + + model.encode_token_weights = enc + + return model +""" + # endregion From c9de7c4e9a3d02ab6f18f105c880a9ba88b667ab Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 8 Jul 2024 19:48:28 +0900 Subject: [PATCH 079/356] WIP: new latents caching --- library/sd3_train_utils.py | 94 +++++++++++++++++++++++- library/train_util.py | 147 ++++++++++++++++++++++++++++++++++++- sd3_train.py | 37 +++++++++- 3 files changed, 270 insertions(+), 8 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index c8d52e1c..9309ee30 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -1,7 +1,7 @@ import argparse import math import os -from typing import Optional, Tuple +from typing import List, Optional, Tuple import torch from safetensors.torch import save_file @@ -283,6 +283,98 @@ def sample_images(*args, **kwargs): return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) +class Sd3LatensCachingStrategy(train_util.LatentsCachingStrategy): + SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" + + def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + self.vae = vae + + def get_latents_npz_path(self, absolute_path: str): + return os.path.splitext(absolute_path)[0] + self.SD3_LATENTS_NPZ_SUFFIX + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + expected_latents_size = (bucket_reso[1] // 8, bucket_reso[0] // 8) # bucket_reso is (W, H) + + try: + npz = np.load(npz_path) + if npz["latents"].shape[1:3] != expected_latents_size: + return False + + if flip_aug: + if "latents_flipped" not in npz: + return False + if npz["latents_flipped"].shape[1:3] != expected_latents_size: + return False + + if alpha_mask: + if "alpha_mask" not in npz: + return False + if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]): + return False + else: + if "alpha_mask" in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): + img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching( + image_infos, alpha_mask, random_crop + ) + img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype) + + with torch.no_grad(): + latents = self.vae.encode(img_tensor).to("cpu") + if flip_aug: + img_tensor = torch.flip(img_tensor, dims=[3]) + with torch.no_grad(): + flipped_latents = self.vae.encode(img_tensor).to("cpu") + else: + flipped_latents = [None] * len(latents) + + for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks): + if self.cache_to_disk: + # save_latents_to_disk( + # info.latents_npz, + # latent, + # info.latents_original_size, + # info.latents_crop_ltrb, + # flipped_latent, + # alpha_mask, + # ) + kwargs = {} + if flipped_latent is not None: + kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy() + if alpha_mask is not None: + kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() + np.savez( + info.latents_npz, + latents=latents.float().cpu().numpy(), + original_size=np.array(original_sizes), + crop_ltrb=np.array(crop_ltrbs), + **kwargs, + ) + else: + info.latents = latent + if flip_aug: + info.latents_flipped = flipped_latent + info.alpha_mask = alpha_mask + + if not train_util.HIGH_VRAM: + clean_memory_on_device(self.vae.device) + + # region Diffusers diff --git a/library/train_util.py b/library/train_util.py index 96d32e3b..8444827d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -359,6 +359,30 @@ class AugHelper: return self.color_aug if use_color_aug else None +class LatentsCachingStrategy: + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + self._cache_to_disk = cache_to_disk + self._batch_size = batch_size + self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + + @property + def cache_to_disk(self): + return self._cache_to_disk + + @property + def batch_size(self): + return self._batch_size + + def get_latents_npz_path(self, absolute_path: str): + raise NotImplementedError + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + raise NotImplementedError + + def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): + raise NotImplementedError + + class BaseSubset: def __init__( self, @@ -986,6 +1010,69 @@ class BaseDataset(torch.utils.data.Dataset): ] ) + def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCachingStrategy): + r""" + a brand new method to cache latents. This method caches latents with caching strategy. + normal cache_latents method is used by default, but this method is used when caching strategy is specified. + """ + logger.info("caching latents with caching strategy.") + image_infos = list(self.image_data.values()) + + # sort by resolution + image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) + + # split by resolution + batches = [] + batch = [] + logger.info("checking cache validity...") + for info in tqdm(image_infos): + subset = self.image_to_subset[info.image_key] + + if info.latents_npz is not None: # fine tuning dataset + continue + + # check disk cache exists and size of latents + if caching_strategy.cache_to_disk: + # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix + info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path) + if not is_main_process: # prepare for multi-gpu, only store to info + continue + + cache_available = caching_strategy.is_disk_cached_latents_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) + if cache_available: # do not add to batch + continue + + # if last member of batch has different resolution, flush the batch + if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: + batches.append(batch) + batch = [] + + batch.append(info) + + # if number of data in batch is enough, flush the batch + if len(batch) >= caching_strategy.batch_size: + batches.append(batch) + batch = [] + + if len(batch) > 0: + batches.append(batch) + + # if cache to disk, don't cache latents in non-main process, set to info only + if caching_strategy.cache_to_disk and not is_main_process: + return + + if len(batches) == 0: + logger.info("no latents to cache") + return + + # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded + logger.info("caching latents...") + for batch in tqdm(batches, smoothing=1, total=len(batches)): + # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + caching_strategy.cache_batch_latents(batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching latents.") @@ -1086,7 +1173,7 @@ class BaseDataset(torch.utils.data.Dataset): if batch_size is None: batch_size = self.batch_size - + image_infos = list(self.image_data.values()) logger.info("checking cache existence...") @@ -2207,6 +2294,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset): logger.info(f"[Dataset {i}]") dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) + def new_cache_latents(self, is_main_process: bool, strategy: LatentsCachingStrategy): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.new_cache_latents(is_main_process, strategy) + def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True ): @@ -2550,6 +2642,51 @@ def trim_and_resize_if_required( return image, original_size, crop_ltrb +# for new_cache_latents +def load_images_and_masks_for_caching( + image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool +) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: + r""" + requires image_infos to have: [absolute_path or image], bucket_reso, resized_size + + returns: image_tensor, alpha_masks, original_sizes, crop_ltrbs + + image_tensor: torch.Tensor = torch.Size([B, 3, H, W]), ...], normalized to [-1, 1] + alpha_masks: List[np.ndarray] = [np.ndarray([H, W]), ...], normalized to [0, 1] + original_sizes: List[Tuple[int, int]] = [(W, H), ...] + crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...] + """ + images: List[torch.Tensor] = [] + alpha_masks: List[np.ndarray] = [] + original_sizes: List[Tuple[int, int]] = [] + crop_ltrbs: List[Tuple[int, int, int, int]] = [] + for info in image_infos: + image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) + # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 + image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + + original_sizes.append(original_size) + crop_ltrbs.append(crop_ltrb) + + if use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] + else: + alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] + else: + alpha_mask = None + alpha_masks.append(alpha_mask) + + image = image[:, :, :3] # remove alpha channel if exists + image = IMAGE_TRANSFORMS(image) + images.append(image) + + img_tensor = torch.stack(images, dim=0) + return img_tensor, alpha_masks, original_sizes, crop_ltrbs + + def cache_batch_latents( vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool ) -> None: @@ -2661,7 +2798,7 @@ def cache_batch_text_encoder_outputs_sd3( ): # make input_ids for each text encoder l_tokens, g_tokens, t5_tokens = input_ids - + clip_l, clip_g, t5xxl = text_encoders with torch.no_grad(): b_lg_out, b_t5_out, b_pool = sd3_utils.get_cond_from_tokens( @@ -2670,8 +2807,12 @@ def cache_batch_text_encoder_outputs_sd3( b_lg_out = b_lg_out.detach() b_t5_out = b_t5_out.detach() b_pool = b_pool.detach() - + for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool): + # debug: NaN check + if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any(): + raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}") + if cache_to_disk: save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool) else: diff --git a/sd3_train.py b/sd3_train.py index de763ac6..c073ec0e 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -204,11 +204,22 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - vae_wrapper = sd3_models.VAEWrapper(vae) # make SD/SDXL compatible - with torch.no_grad(): - train_dataset_group.cache_latents( - vae_wrapper, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process, file_suffix="_sd3.npz" + + if not args.new_caching: + vae_wrapper = sd3_models.VAEWrapper(vae) # make SD/SDXL compatible + with torch.no_grad(): + train_dataset_group.cache_latents( + vae_wrapper, + args.vae_batch_size, + args.cache_latents_to_disk, + accelerator.is_main_process, + file_suffix="_sd3.npz", + ) + else: + strategy = sd3_train_utils.Sd3LatensCachingStrategy( + vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check ) + train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy) vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -699,6 +710,17 @@ def train(args): sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + # debug: NaN check for all inputs + if torch.any(torch.isnan(noisy_model_input)): + accelerator.print("NaN found in noisy_model_input, replacing with zeros") + noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input) + if torch.any(torch.isnan(context)): + accelerator.print("NaN found in context, replacing with zeros") + context = torch.nan_to_num(context, 0, out=context) + if torch.any(torch.isnan(pool)): + accelerator.print("NaN found in pool, replacing with zeros") + pool = torch.nan_to_num(pool, 0, out=pool) + # call model with accelerator.autocast(): model_pred = mmdit(noisy_model_input, timesteps, context=context, y=pool) @@ -908,6 +930,13 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", ) + + parser.add_argument("--new_caching", action="store_true", help="use new caching method / 新しいキャッシング方法を使う") + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="skip latents validity check / latentsの正当性チェックをスキップする", + ) return parser From 3ea4fce5e0f3d1a9c2718d77f49c3b304d25e565 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 8 Jul 2024 22:04:43 +0900 Subject: [PATCH 080/356] load models one by one --- library/sd3_train_utils.py | 56 ++++++------ library/sd3_utils.py | 169 +++++++++++++++++++++++++++++++++++++ sd3_train.py | 58 +++++++++---- 3 files changed, 236 insertions(+), 47 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 9309ee30..98ee66bf 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -1,19 +1,17 @@ import argparse import math import os -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import torch from safetensors.torch import save_file +from accelerate import Accelerator from library import sd3_models, sd3_utils, train_util from library.device_utils import init_ipex, clean_memory_on_device init_ipex() -from accelerate import init_empty_weights -from tqdm import tqdm - # from transformers import CLIPTokenizer # from library import model_util # , sdxl_model_util, train_util, sdxl_original_unet @@ -28,50 +26,48 @@ logger = logging.getLogger(__name__) from .sdxl_train_util import match_mixed_precision -def load_target_model(args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype) -> Tuple[ +def load_target_model( + model_type: str, + args: argparse.Namespace, + state_dict: dict, + accelerator: Accelerator, + attn_mode: str, + model_dtype: Optional[torch.dtype], + device: Optional[torch.device], +) -> Union[ sd3_models.MMDiT, Optional[sd3_models.SDClipModel], Optional[sd3_models.SDXLClipG], Optional[sd3_models.T5XXLModel], sd3_models.SDVAE, ]: - model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16, None or fp16/bf16 + loading_device = device if device is not None else (accelerator.device if args.lowram else "cpu") for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") - mmdit, clip_l, clip_g, t5xxl, vae = sd3_utils.load_models( - args.pretrained_model_name_or_path, - args.clip_l, - args.clip_g, - args.t5xxl, - args.vae, - attn_mode, - accelerator.device if args.lowram else "cpu", - model_dtype, - args.disable_mmap_load_safetensors, - clip_dtype, - t5xxl_device, - t5xxl_dtype, - vae_dtype, - ) + if model_type == "mmdit": + model = sd3_utils.load_mmdit(state_dict, attn_mode, model_dtype, loading_device) + elif model_type == "clip_l": + model = sd3_utils.load_clip_l(state_dict, args.clip_l, attn_mode, model_dtype, loading_device) + elif model_type == "clip_g": + model = sd3_utils.load_clip_g(state_dict, args.clip_g, attn_mode, model_dtype, loading_device) + elif model_type == "t5xxl": + model = sd3_utils.load_t5xxl(state_dict, args.t5xxl, attn_mode, model_dtype, loading_device) + elif model_type == "vae": + model = sd3_utils.load_vae(state_dict, args.vae, model_dtype, loading_device) + else: + raise ValueError(f"Unknown model type: {model_type}") # work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device if args.lowram: - if clip_l is not None: - clip_l.to(accelerator.device) - if clip_g is not None: - clip_g.to(accelerator.device) - if t5xxl is not None: - t5xxl.to(accelerator.device) - vae.to(accelerator.device) - mmdit.to(accelerator.device) + model = model.to(accelerator.device) clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() - return mmdit, clip_l, clip_g, t5xxl, vae + return model def save_models( diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 9dc9e796..16f80c60 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -20,6 +20,175 @@ from library import sdxl_model_util # region models +def load_safetensors(path: str, dvc: Union[str, torch.device], disable_mmap: bool = False): + if disable_mmap: + return safetensors.torch.load(open(path, "rb").read()) + else: + try: + return load_file(path, device=dvc) + except: + return load_file(path) # prevent device invalid Error + + +def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device]): + mmdit_sd = {} + + mmdit_prefix = "model.diffusion_model." + for k in list(state_dict.keys()): + if k.startswith(mmdit_prefix): + mmdit_sd[k[len(mmdit_prefix) :]] = state_dict.pop(k) + + # load MMDiT + logger.info("Building MMDit") + with init_empty_weights(): + mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) + + logger.info("Loading state dict...") + info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype) + logger.info(f"Loaded MMDiT: {info}") + return mmdit + + +def load_clip_l( + state_dict: Dict, + clip_l_path: Optional[str], + attn_mode: str, + clip_dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, +): + clip_l_sd = None + if clip_l_path: + logger.info(f"Loading clip_l from {clip_l_path}...") + clip_l_sd = load_safetensors(clip_l_path, device, disable_mmap) + for key in list(clip_l_sd.keys()): + clip_l_sd["transformer." + key] = clip_l_sd.pop(key) + else: + if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_l: remove prefix "text_encoders.clip_l." + logger.info("clip_l is included in the checkpoint") + clip_l_sd = {} + prefix = "text_encoders.clip_l." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) + + if clip_l_sd is None: + clip_l = None + else: + logger.info("Building ClipL") + clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd) + logger.info("Loading state dict...") + info = clip_l.load_state_dict(clip_l_sd) + logger.info(f"Loaded ClipL: {info}") + clip_l.set_attn_mode(attn_mode) + return clip_l + + +def load_clip_g( + state_dict: Dict, + clip_g_path: Optional[str], + attn_mode: str, + clip_dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, +): + clip_g_sd = None + if clip_g_path: + logger.info(f"Loading clip_g from {clip_g_path}...") + clip_g_sd = load_safetensors(clip_g_path, device, disable_mmap) + for key in list(clip_g_sd.keys()): + clip_g_sd["transformer." + key] = clip_g_sd.pop(key) + else: + if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_g: remove prefix "text_encoders.clip_g." + logger.info("clip_g is included in the checkpoint") + clip_g_sd = {} + prefix = "text_encoders.clip_g." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) + + if clip_g_sd is None: + clip_g = None + else: + logger.info("Building ClipG") + clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd) + logger.info("Loading state dict...") + info = clip_g.load_state_dict(clip_g_sd) + logger.info(f"Loaded ClipG: {info}") + clip_g.set_attn_mode(attn_mode) + return clip_g + + +def load_t5xxl( + state_dict: Dict, + t5xxl_path: Optional[str], + attn_mode: str, + dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, +): + t5xxl_sd = None + if t5xxl_path: + logger.info(f"Loading t5xxl from {t5xxl_path}...") + t5xxl_sd = load_safetensors(t5xxl_path, device, disable_mmap) + for key in list(t5xxl_sd.keys()): + t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) + else: + if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: + # found t5xxl: remove prefix "text_encoders.t5xxl." + logger.info("t5xxl is included in the checkpoint") + t5xxl_sd = {} + prefix = "text_encoders.t5xxl." + for k in list(state_dict.keys()): + if k.startswith(prefix): + t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) + + if t5xxl_sd is None: + t5xxl = None + else: + logger.info("Building T5XXL") + + # workaround for T5XXL model creation: create with fp16 takes too long TODO support virtual device + t5xxl = sd3_models.create_t5xxl(device, torch.float32, t5xxl_sd) + t5xxl.to(dtype=dtype) + + logger.info("Loading state dict...") + info = t5xxl.load_state_dict(t5xxl_sd) + logger.info(f"Loaded T5XXL: {info}") + t5xxl.set_attn_mode(attn_mode) + return t5xxl + + +def load_vae( + state_dict: Dict, + vae_path: Optional[str], + vae_dtype: Optional[Union[str, torch.dtype]], + device: Optional[Union[str, torch.device]], + disable_mmap: bool = False, +): + vae_sd = {} + if vae_path: + logger.info(f"Loading VAE from {vae_path}...") + vae_sd = load_safetensors(vae_path, device, disable_mmap) + else: + # remove prefix "first_stage_model." + vae_sd = {} + vae_prefix = "first_stage_model." + for k in list(state_dict.keys()): + if k.startswith(vae_prefix): + vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) + + logger.info("Building VAE") + vae = sd3_models.SDVAE() + logger.info("Loading state dict...") + info = vae.load_state_dict(vae_sd) + logger.info(f"Loaded VAE: {info}") + vae.to(device=device, dtype=vae_dtype) + return vae + + def load_models( ckpt_path: str, clip_l_path: str, diff --git a/sd3_train.py b/sd3_train.py index c073ec0e..10cc5d57 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -13,12 +13,12 @@ from tqdm import tqdm import torch from library.device_utils import init_ipex, clean_memory_on_device - init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils +from library.sdxl_train_util import match_mixed_precision # , sdxl_model_util @@ -189,18 +189,19 @@ def train(args): assert ( attn_mode == "torch" - ), f"attn_mode {attn_mode} is not supported. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" - # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0. - mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( - args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype + # SD3 state dict may contain multiple models, so we need to load it and extract one by one. annoying. + logger.info(f"Loading SD3 models from {args.pretrained_model_name_or_path}") + device_to_load = accelerator.device if args.lowram else "cpu" + sd3_state_dict = sd3_utils.load_safetensors( + args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors ) - assert clip_l is not None, "clip_l is required / clip_lは必須です" - assert clip_g is not None, "clip_g is required / clip_gは必須です" - # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) - # 学習を準備する + # load VAE for caching latents + vae: sd3_models.SDVAE = None if cache_latents: + vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() @@ -220,15 +221,25 @@ def train(args): vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check ) train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy) - vae.to("cpu") + vae.to("cpu") # if no sampling, vae can be deleted clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + # load clip_l, clip_g, t5xxl for caching text encoder outputs + # # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0. + # mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( + # args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype + # ) + clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) + clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) + assert clip_l is not None, "clip_l is required / clip_lは必須です" + assert clip_g is not None, "clip_g is required / clip_gは必須です" + + t5xxl = sd3_train_utils.load_target_model("t5xxl", args, sd3_state_dict, accelerator, attn_mode, t5xxl_dtype, device_to_load) + # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) + # 学習を準備する:モデルを適切な状態にする - if args.gradient_checkpointing: - mmdit.enable_gradient_checkpointing() - train_mmdit = args.learning_rate != 0 train_clip_l = False train_clip_g = False train_t5xxl = False @@ -280,17 +291,30 @@ def train(args): accelerator.is_main_process, args.text_encoder_batch_size, ) + + # TODO we can delete text encoders after caching accelerator.wait_for_everyone() + # load MMDIT + # if full_fp16/bf16, model_dtype is casted to fp16/bf16. If not, model_dtype is None (float32). + # by loading with model_dtype, we can reduce memory usage. + model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx) + mmdit = sd3_train_utils.load_target_model("mmdit", args, sd3_state_dict, accelerator, attn_mode, model_dtype, device_to_load) + if args.gradient_checkpointing: + mmdit.enable_gradient_checkpointing() + + train_mmdit = args.learning_rate != 0 + mmdit.requires_grad_(train_mmdit) + if not train_mmdit: + mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdie will not be prepared + if not cache_latents: + # load VAE here if not cached + vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) vae.requires_grad_(False) vae.eval() vae.to(accelerator.device, dtype=vae_dtype) - mmdit.requires_grad_(train_mmdit) - if not train_mmdit: - mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared - training_models = [] params_to_optimize = [] # if train_unet: From 9dc7997803d70c718969526352e88908e827f091 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 9 Jul 2024 20:37:00 +0900 Subject: [PATCH 081/356] fix typo --- library/sd3_models.py | 2 +- library/sd3_train_utils.py | 2 +- sd3_train.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index e4c0790d..a1ff1e75 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -1643,7 +1643,7 @@ class T5LayerNorm(torch.nn.Module): # copy from transformers' T5LayerNorm def forward(self, hidden_states): # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for # half-precision inputs is done in fp32 variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 98ee66bf..66034210 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -279,7 +279,7 @@ def sample_images(*args, **kwargs): return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) -class Sd3LatensCachingStrategy(train_util.LatentsCachingStrategy): +class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy): SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: diff --git a/sd3_train.py b/sd3_train.py index 10cc5d57..30d994c7 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -217,7 +217,7 @@ def train(args): file_suffix="_sd3.npz", ) else: - strategy = sd3_train_utils.Sd3LatensCachingStrategy( + strategy = sd3_train_utils.Sd3LatentsCachingStrategy( vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check ) train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy) From 3d402927efb2d396f8f33fe6a1747e43f7a5f0f3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 9 Jul 2024 23:15:38 +0900 Subject: [PATCH 082/356] WIP: update new latents caching --- library/sd3_train_utils.py | 49 +++++++++++++++++++++++++------------- library/train_util.py | 39 ++++++++++++++++++++++++++---- sd3_train.py | 15 ++++++++---- 3 files changed, 77 insertions(+), 26 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 66034210..24591219 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -1,4 +1,5 @@ import argparse +import glob import math import os from typing import List, Optional, Tuple, Union @@ -282,12 +283,26 @@ def sample_images(*args, **kwargs): class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy): SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" - def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + self.vae = None + + def set_vae(self, vae: sd3_models.SDVAE): self.vae = vae - def get_latents_npz_path(self, absolute_path: str): - return os.path.splitext(absolute_path)[0] + self.SD3_LATENTS_NPZ_SUFFIX + def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX) + if len(npz_file) == 0: + return None, None + w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX + ) def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): if not self.cache_to_disk: @@ -331,24 +346,24 @@ class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy): img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype) with torch.no_grad(): - latents = self.vae.encode(img_tensor).to("cpu") + latents_tensors = self.vae.encode(img_tensor).to("cpu") if flip_aug: img_tensor = torch.flip(img_tensor, dims=[3]) with torch.no_grad(): flipped_latents = self.vae.encode(img_tensor).to("cpu") else: - flipped_latents = [None] * len(latents) + flipped_latents = [None] * len(latents_tensors) + + # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks): + for i in range(len(image_infos)): + info = image_infos[i] + latents = latents_tensors[i] + flipped_latent = flipped_latents[i] + alpha_mask = alpha_masks[i] + original_size = original_sizes[i] + crop_ltrb = crop_ltrbs[i] - for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks): if self.cache_to_disk: - # save_latents_to_disk( - # info.latents_npz, - # latent, - # info.latents_original_size, - # info.latents_crop_ltrb, - # flipped_latent, - # alpha_mask, - # ) kwargs = {} if flipped_latent is not None: kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy() @@ -357,12 +372,12 @@ class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy): np.savez( info.latents_npz, latents=latents.float().cpu().numpy(), - original_size=np.array(original_sizes), - crop_ltrb=np.array(crop_ltrbs), + original_size=np.array(original_size), + crop_ltrb=np.array(crop_ltrb), **kwargs, ) else: - info.latents = latent + info.latents = latents if flip_aug: info.latents_flipped = flipped_latent info.alpha_mask = alpha_mask diff --git a/library/train_util.py b/library/train_util.py index 8444827d..9db226ea 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -360,11 +360,23 @@ class AugHelper: class LatentsCachingStrategy: + _strategy = None # strategy instance: actual strategy class + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: self._cache_to_disk = cache_to_disk self._batch_size = batch_size self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["LatentsCachingStrategy"]: + return cls._strategy + @property def cache_to_disk(self): return self._cache_to_disk @@ -373,10 +385,15 @@ class LatentsCachingStrategy: def batch_size(self): return self._batch_size - def get_latents_npz_path(self, absolute_path: str): + def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: raise NotImplementedError - def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + def get_latents_npz_path(self, absolute_path: str, bucket_reso: Tuple[int, int]) -> str: + raise NotImplementedError + + def is_disk_cached_latents_expected( + self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + ) -> bool: raise NotImplementedError def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): @@ -1034,7 +1051,7 @@ class BaseDataset(torch.utils.data.Dataset): # check disk cache exists and size of latents if caching_strategy.cache_to_disk: # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix - info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path) + info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) if not is_main_process: # prepare for multi-gpu, only store to info continue @@ -1730,6 +1747,18 @@ class DreamBoothDataset(BaseDataset): img_paths = glob_images(subset.image_dir, "*") sizes = [None] * len(img_paths) + # new caching: get image size from cache files + strategy = LatentsCachingStrategy.get_strategy() + if strategy is not None: + logger.info("get image size from cache files") + size_set_count = 0 + for i, img_path in enumerate(tqdm(img_paths)): + w, h = strategy.get_image_size_from_image_absolute_path(img_path) + if w is not None and h is not None: + sizes[i] = [w, h] + size_set_count += 1 + logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") if use_cached_info_for_subset: @@ -2807,12 +2836,12 @@ def cache_batch_text_encoder_outputs_sd3( b_lg_out = b_lg_out.detach() b_t5_out = b_t5_out.detach() b_pool = b_pool.detach() - + for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool): # debug: NaN check if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any(): raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}") - + if cache_to_disk: save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool) else: diff --git a/sd3_train.py b/sd3_train.py index 30d994c7..e2f622e4 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -91,6 +91,15 @@ def train(args): # load tokenizer sd3_tokenizer = sd3_models.SD3Tokenizer() + # prepare caching strategy + if args.new_caching: + latents_caching_strategy = sd3_train_utils.Sd3LatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + ) + else: + latents_caching_strategy = None + train_util.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + # データセットを準備する if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) @@ -217,10 +226,8 @@ def train(args): file_suffix="_sd3.npz", ) else: - strategy = sd3_train_utils.Sd3LatentsCachingStrategy( - vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check - ) - train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy) + latents_caching_strategy.set_vae(vae) + train_dataset_group.new_cache_latents(accelerator.is_main_process, latents_caching_strategy) vae.to("cpu") # if no sampling, vae can be deleted clean_memory_on_device(accelerator.device) From 6f0e235f2cb9a9829bc12280c29e12c0ae66c88f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 11 Jul 2024 08:00:45 +0900 Subject: [PATCH 083/356] Fix shift value in SD3 inference. --- sd3_minimal_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index 7f5f28ce..ffa0d46d 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -64,7 +64,7 @@ def do_sample( device: str, ): if initial_latent is None: - # latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 + # latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 # this seems to be a bug in the original code. thanks to furusu for pointing it out latent = torch.zeros(1, 16, height // 8, width // 8, device=device) else: latent = initial_latent @@ -73,7 +73,7 @@ def do_sample( noise = get_noise(seed, latent).to(device) - model_sampling = sd3_utils.ModelSamplingDiscreteFlow() + model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 sigmas = get_sigmas(model_sampling, steps).to(device) # sigmas = sigmas[int(steps * (1 - denoise)) :] # do not support i2i From b8896aad400222c8c4441b217fda0f9bb0807ffd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 11 Jul 2024 08:01:23 +0900 Subject: [PATCH 084/356] update README --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3eed636c..5d4f9621 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,9 @@ This repository contains training, generation and utility scripts for Stable Dif SD3 training is done with `sd3_train.py`. -__Jun 29, 2024__: Fixed mixed precision training with fp16 is not working. Fixed the model is in bf16 dtype even without `--full_bf16` option (this could worsen the training result). +__Jul 11, 2024__: Fixed to work t5xxl with `fp16`. If you change the dtype to `fp16` for t5xxl, please remove existing latents cache files (`*_sd3.npz`). The shift in `sd3_minimum_inference.py` is fixed to 3.0. Thanks to araleza! + +Jun 29, 2024: Fixed mixed precision training with fp16 is not working. Fixed the model is in bf16 dtype even without `--full_bf16` option (this could worsen the training result). `fp16` and `bf16` are available for mixed precision training. We are not sure which is better. @@ -12,7 +14,7 @@ __Jun 29, 2024__: Fixed mixed precision training with fp16 is not working. Fixed `clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them. -t5xxl doesn't seem to work with `fp16`, so 1) use`bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`. +~~t5xxl doesn't seem to work with `fp16`, so 1) use`bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`. ~~ t5xxl works with `fp16` now. There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype. From 082f13658bdbaed872ede6c0a7a75ab1a5f3712d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 12 Jul 2024 21:28:01 +0900 Subject: [PATCH 085/356] reduce peak GPU memory usage before training --- library/sd3_models.py | 2 +- library/train_util.py | 1 + sd3_train.py | 44 +++++++++++++++++++++---------------------- 3 files changed, 24 insertions(+), 23 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index a1ff1e75..ec8e1bbd 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -471,7 +471,7 @@ class AttentionLinears(nn.Module): num_heads: int = 8, qkv_bias: bool = False, pre_only: bool = False, - qk_norm: str = None, + qk_norm: Optional[str] = None, ): super().__init__() self.num_heads = num_heads diff --git a/library/train_util.py b/library/train_util.py index 9db226ea..7af0070e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2410,6 +2410,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) +# TODO update to use CachingStrategy def load_latents_from_disk( npz_path, ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: diff --git a/sd3_train.py b/sd3_train.py index e2f622e4..f34e4712 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -458,6 +458,28 @@ def train(args): # text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) # text_encoder1.text_model.final_layer_norm.requires_grad_(False) + # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + clip_l.to("cpu", dtype=torch.float32) + clip_g.to("cpu", dtype=torch.float32) + if t5xxl is not None: + t5xxl.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + # TODO support CPU for text encoders + clip_l.to(accelerator.device) + clip_g.to(accelerator.device) + if t5xxl is not None: + t5xxl.to(accelerator.device) + + # TODO cache sample prompt's embeddings to free text encoder's memory + if args.cache_text_encoder_outputs: + if not args.save_t5xxl: + t5xxl = None # free memory + clean_memory_on_device(accelerator.device) + if args.deepspeed: ds_model = deepspeed_utils.prepare_deepspeed_model( args, @@ -482,28 +504,6 @@ def train(args): # text_encoder2 = accelerator.prepare(text_encoder2) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) - # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する - if args.cache_text_encoder_outputs: - # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 - clip_l.to("cpu", dtype=torch.float32) - clip_g.to("cpu", dtype=torch.float32) - if t5xxl is not None: - t5xxl.to("cpu", dtype=torch.float32) - clean_memory_on_device(accelerator.device) - else: - # make sure Text Encoders are on GPU - # TODO support CPU for text encoders - clip_l.to(accelerator.device) - clip_g.to(accelerator.device) - if t5xxl is not None: - t5xxl.to(accelerator.device) - - # TODO cache sample prompt's embeddings to free text encoder's memory - if args.cache_text_encoder_outputs: - if not args.save_t5xxl: - t5xxl = None # free memory - clean_memory_on_device(accelerator.device) - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. From 87526942a67fd71bb775bc479b0a7449df516dd8 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Fri, 12 Jul 2024 22:56:38 +0800 Subject: [PATCH 086/356] judge image size for using diff interpolation --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3c..74720fec 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2362,7 +2362,7 @@ def trim_and_resize_if_required( if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA if image_width > resized_size[0] and image_height > resized_size[1] else cv2.INTER_LANCZOS4) image_height, image_width = image.shape[0:2] From 2e67978ee243a20f169ce76d7644bb1f9dec9bad Mon Sep 17 00:00:00 2001 From: Millie Date: Thu, 18 Jul 2024 11:52:58 -0700 Subject: [PATCH 087/356] Generate sample images without having CUDA (such as on Macs) --- library/train_util.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3c..9b0397d7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5229,7 +5229,7 @@ def sample_images_common( clean_memory_on_device(accelerator.device) torch.set_rng_state(rng_state) - if cuda_rng_state is not None: + if torch.cuda.is_available() and cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) @@ -5263,11 +5263,13 @@ def sample_image_inference( if seed is not None: torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) else: # True random sample image generation torch.seed() - torch.cuda.seed() + if torch.cuda.is_available(): + torch.cuda.seed() scheduler = get_my_scheduler( sample_sampler=sampler_name, @@ -5302,8 +5304,9 @@ def sample_image_inference( controlnet_image=controlnet_image, ) - with torch.cuda.device(torch.cuda.current_device()): - torch.cuda.empty_cache() + if torch.cuda.is_available(): + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() image = pipeline.latents_to_image(latents)[0] From 1f16b80e88b1c4f05d49b4fc328d3b9b105ebcbe Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 20 Jul 2024 21:35:24 +0800 Subject: [PATCH 088/356] Revert "judge image size for using diff interpolation" This reverts commit 87526942a67fd71bb775bc479b0a7449df516dd8. --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 74720fec..15c23f3c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2362,7 +2362,7 @@ def trim_and_resize_if_required( if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA if image_width > resized_size[0] and image_height > resized_size[1] else cv2.INTER_LANCZOS4) + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ image_height, image_width = image.shape[0:2] From 9ca7a5b6cc99e25820a1aa6d02a779004d73bca0 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 20 Jul 2024 21:59:11 +0800 Subject: [PATCH 089/356] instead cv2 LANCZOS4 resize to pil resize --- finetune/tag_images_by_wd14_tagger.py | 8 +++++--- library/train_util.py | 11 ++++++----- library/utils.py | 14 +++++++++++++- tools/detect_face_rotate.py | 7 +++++-- tools/resize_images_to_resolution.py | 11 +++++++---- 5 files changed, 36 insertions(+), 15 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index a327bbd6..6f5bdd36 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -11,7 +11,7 @@ from PIL import Image from tqdm import tqdm import library.train_util as train_util -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging @@ -42,8 +42,10 @@ def preprocess_image(image): pad_t = pad_y // 2 image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255) - interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 - image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) + if size > IMAGE_SIZE: + image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA) + else: + image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE)) image = image.astype(np.float32) return image diff --git a/library/train_util.py b/library/train_util.py index 15c23f3c..160e3b44 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -71,7 +71,7 @@ import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec import library.deepspeed_utils as deepspeed_utils -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging @@ -2028,9 +2028,7 @@ class ControlNetDataset(BaseDataset): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img = cv2.resize( - cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4 - ) + cond_img=pil_resize(cond_img,(int(target_size_hw[1]), int(target_size_hw[0]))) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -2362,7 +2360,10 @@ def trim_and_resize_if_required( if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + if image_width > resized_size[0] and image_height > resized_size[1]: + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + else: + image = pil_resize(image, resized_size) image_height, image_width = image.shape[0:2] diff --git a/library/utils.py b/library/utils.py index 3037c055..a219f6cb 100644 --- a/library/utils.py +++ b/library/utils.py @@ -7,7 +7,9 @@ from typing import * from diffusers import EulerAncestralDiscreteScheduler import diffusers.schedulers.scheduling_euler_ancestral_discrete from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput - +import cv2 +from PIL import Image +import numpy as np def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() @@ -78,7 +80,17 @@ def setup_logging(args=None, log_level=None, reset=False): logger = logging.getLogger(__name__) logger.info(msg_init) +def pil_resize(image, size, interpolation=Image.LANCZOS): + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # use Pillow resize + resized_pil = pil_image.resize(size, interpolation) + + # return cv2 image + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) + + return resized_cv2 # TODO make inf_utils.py diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py index bbc643ed..d2a4d9cf 100644 --- a/tools/detect_face_rotate.py +++ b/tools/detect_face_rotate.py @@ -15,7 +15,7 @@ import os from anime_face_detector import create_detector from tqdm import tqdm import numpy as np -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging logger = logging.getLogger(__name__) @@ -172,7 +172,10 @@ def process(args): if scale != 1.0: w = int(w * scale + .5) h = int(h * scale + .5) - face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4) + if scale < 1.0: + face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA) + else: + face_img = pil_resize(face_img, (w, h)) cx = int(cx * scale + .5) cy = int(cy * scale + .5) fw = int(fw * scale + .5) diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py index b8069fc1..0f9e00b1 100644 --- a/tools/resize_images_to_resolution.py +++ b/tools/resize_images_to_resolution.py @@ -6,7 +6,7 @@ import shutil import math from PIL import Image import numpy as np -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging logger = logging.getLogger(__name__) @@ -24,9 +24,9 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi # Select interpolation method if interpolation == 'lanczos4': - cv2_interpolation = cv2.INTER_LANCZOS4 + pil_interpolation = Image.LANCZOS elif interpolation == 'cubic': - cv2_interpolation = cv2.INTER_CUBIC + pil_interpolation = Image.BICUBIC else: cv2_interpolation = cv2.INTER_AREA @@ -64,7 +64,10 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi new_width = int(img.shape[1] * math.sqrt(scale_factor)) # Resize image - img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) + if cv2_interpolation: + img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) + else: + img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation) else: new_height, new_width = img.shape[0:2] From 41dee60383a3b88859b80929a2c0d94b12c42068 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 27 Jul 2024 13:50:05 +0900 Subject: [PATCH 090/356] Refactor caching mechanism for latents and text encoder outputs, etc. --- README.md | 21 +- fine_tune.py | 54 ++-- library/config_util.py | 2 - library/sd3_models.py | 47 +++- library/sd3_train_utils.py | 105 -------- library/sd3_utils.py | 1 + library/sdxl_train_util.py | 2 +- library/strategy_base.py | 328 +++++++++++++++++++++++ library/strategy_sd.py | 139 ++++++++++ library/strategy_sd3.py | 229 ++++++++++++++++ library/strategy_sdxl.py | 247 +++++++++++++++++ library/train_util.py | 443 +++++++++++++++---------------- sd3_minimal_inference.py | 22 +- sd3_train.py | 270 +++++++++++-------- sdxl_train.py | 108 ++++---- sdxl_train_control_net_lllite.py | 99 ++++--- sdxl_train_network.py | 48 +++- sdxl_train_textual_inversion.py | 47 ++-- train_db.py | 67 +++-- train_network.py | 122 ++++++--- train_textual_inversion.py | 118 ++++---- 21 files changed, 1786 insertions(+), 733 deletions(-) create mode 100644 library/strategy_base.py create mode 100644 library/strategy_sd.py create mode 100644 library/strategy_sd3.py create mode 100644 library/strategy_sdxl.py diff --git a/README.md b/README.md index 5d4f9621..d406fecd 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,16 @@ This repository contains training, generation and utility scripts for Stable Dif SD3 training is done with `sd3_train.py`. -__Jul 11, 2024__: Fixed to work t5xxl with `fp16`. If you change the dtype to `fp16` for t5xxl, please remove existing latents cache files (`*_sd3.npz`). The shift in `sd3_minimum_inference.py` is fixed to 3.0. Thanks to araleza! +__Jul 27, 2024__: +- Latents and text encoder outputs caching mechanism is refactored significantly. + - Existing cache files for SD3 need to be recreated. Please delete the previous cache files. + - With this change, dataset initialization is significantly faster, especially for large datasets. -Jun 29, 2024: Fixed mixed precision training with fp16 is not working. Fixed the model is in bf16 dtype even without `--full_bf16` option (this could worsen the training result). +- Architecture-dependent parts are extracted from the dataset (`train_util.py`). This is expected to make it easier to add future architectures. + +- Architecture-dependent parts including the cache mechanism for SD1/2/SDXL are also extracted. The basic operation of SD1/2/SDXL training on the sd3 branch has been confirmed, but there may be bugs. Please use the main or dev branch for SD1/2/SDXL training. + +--- `fp16` and `bf16` are available for mixed precision training. We are not sure which is better. @@ -14,7 +21,7 @@ Jun 29, 2024: Fixed mixed precision training with fp16 is not working. Fixed the `clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them. -~~t5xxl doesn't seem to work with `fp16`, so 1) use`bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`. ~~ t5xxl works with `fp16` now. +t5xxl works with `fp16` now. There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype. @@ -32,6 +39,14 @@ cache_latents = true cache_latents_to_disk = true ``` +__2024/7/27:__ + +Latents およびテキストエンコーダ出力のキャッシュの仕組みを大きくリファクタリングしました。SD3 用の既存のキャッシュファイルの再作成が必要になりますが、ご了承ください(以前のキャッシュファイルは削除してください)。これにより、特にデータセットの規模が大きい場合のデータセット初期化が大幅に高速化されます。 + +データセット (`train_util.py`) からアーキテクチャ依存の部分を切り出しました。これにより将来的なアーキテクチャ追加が容易になると期待しています。 + +SD1/2/SDXL のキャッシュ機構を含むアーキテクチャ依存の部分も切り出しました。sd3 ブランチの SD1/2/SDXL 学習について、基本的な動作は確認していますが、不具合があるかもしれません。SD1/2/SDXL の学習には main または dev ブランチをお使いください。 + --- [__Change History__](#change-history) is moved to the bottom of the page. diff --git a/fine_tune.py b/fine_tune.py index d865cd2d..c9102f6c 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -10,7 +10,7 @@ import toml from tqdm import tqdm import torch -from library import deepspeed_utils +from library import deepspeed_utils, strategy_base from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -39,6 +39,7 @@ from library.custom_train_functions import ( scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, ) +import library.strategy_sd as strategy_sd def train(args): @@ -52,7 +53,15 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - tokenizer = train_util.load_tokenizer(args) + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if cache_latents: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -81,10 +90,10 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -165,8 +174,9 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -192,6 +202,9 @@ def train(args): else: text_encoder.eval() + text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + if not cache_latents: vae.requires_grad_(False) vae.eval() @@ -214,7 +227,11 @@ def train(args): accelerator.print("prepare optimizer, data loader etc.") _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -317,7 +334,9 @@ def train(args): ) # For --sample_at_first - train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -342,8 +361,9 @@ def train(args): with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning if args.weighted_captions: + # TODO move to strategy_sd.py encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, + tokenize_strategy.tokenizer, text_encoder, batch["captions"], accelerator.device, @@ -351,10 +371,12 @@ def train(args): clip_skip=args.clip_skip, ) else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) + input_ids = batch["input_ids_list"][0].to(accelerator.device) + encoder_hidden_states = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder], [input_ids] + )[0] + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -409,7 +431,7 @@ def train(args): global_step += 1 train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet ) # 指定ステップごとにモデルを保存 @@ -472,7 +494,9 @@ def train(args): vae, ) - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) is_main_process = accelerator.is_main_process if is_main_process: diff --git a/library/config_util.py b/library/config_util.py index 10b2457f..f8cdfe60 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -104,8 +104,6 @@ class ControlNetSubsetParams(BaseSubsetParams): @dataclass class BaseDatasetParams: - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None - max_token_length: int = None resolution: Optional[Tuple[int, int]] = None network_multiplier: float = 1.0 debug_dataset: bool = False diff --git a/library/sd3_models.py b/library/sd3_models.py index ec8e1bbd..28378c73 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -38,7 +38,7 @@ class SDTokenizer: サブクラスで各種の設定を行ってる。このクラスはその設定に基づき重み付きのトークン化を行うようだ。 Some settings are done in subclasses. This class seems to perform tokenization with weights based on those settings. """ - self.tokenizer = tokenizer + self.tokenizer: CLIPTokenizer = tokenizer self.max_length = max_length self.min_length = min_length empty = self.tokenizer("")["input_ids"] @@ -56,6 +56,19 @@ class SDTokenizer: self.inv_vocab = {v: k for k, v in vocab.items()} self.max_word_length = 8 + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + """ + Tokenize the text without weights. + """ + if type(text) == str: + text = [text] + batch_tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt") + # return tokens["input_ids"] + + pad_token = self.end_token if self.pad_with_end else 0 + for tokens in batch_tokens["input_ids"]: + assert tokens[0] == self.start_token, f"tokens[0]: {tokens[0]}, start_token: {self.start_token}" + def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None): """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" @@ -75,13 +88,14 @@ class SDTokenizer: for word in to_tokenize: batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]]) batch.append((self.end_token, 1.0)) + print(len(batch), self.max_length, self.min_length) if self.pad_to_max_length: batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) if self.min_length is not None and len(batch) < self.min_length: batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) # truncate to max_length - # print(f"batch: {batch}, truncate: {truncate}, len(batch): {len(batch)}, max_length: {self.max_length}") + print(f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}") if truncate_to_max_length and len(batch) > self.max_length: batch = batch[: self.max_length] if truncate_length is not None and len(batch) > truncate_length: @@ -110,27 +124,38 @@ class SDXLClipGTokenizer(SDTokenizer): class SD3Tokenizer: - def __init__(self, t5xxl=True): + def __init__(self, t5xxl=True, t5xxl_max_length: Optional[int] = 256): + if t5xxl_max_length is None: + t5xxl_max_length = 256 + # TODO cache tokenizer settings locally or hold them in the repo like ComfyUI clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) self.clip_g = SDXLClipGTokenizer(clip_tokenizer) + # self.clip_l = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + # self.clip_g = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") self.t5xxl = T5XXLTokenizer() if t5xxl else None # t5xxl has 99999999 max length, clip has 77 - self.model_max_length = self.clip_l.max_length # 77 + self.t5xxl_max_length = t5xxl_max_length def tokenize_with_weights(self, text: str): - # temporary truncate to max_length even for t5xxl return ( self.clip_l.tokenize_with_weights(text), self.clip_g.tokenize_with_weights(text), ( - self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.model_max_length) + self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.t5xxl_max_length) if self.t5xxl is not None else None ), ) + def tokenize(self, text: str): + return ( + self.clip_l.tokenize(text), + self.clip_g.tokenize(text), + (self.t5xxl.tokenize(text) if self.t5xxl is not None else None), + ) + # endregion @@ -1474,7 +1499,10 @@ class ClipTokenWeightEncoder: tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0] list_of_tokens.append(tokens) else: - list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] + if isinstance(list_of_token_weight_pairs[0], torch.Tensor): + list_of_tokens = [list(list_of_token_weight_pairs[0])] + else: + list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] out, pooled = self(list_of_tokens) if has_batch: @@ -1614,9 +1642,9 @@ class T5XXLModel(SDClipModel): ### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl ################################################################################################# - +""" class T5XXLTokenizer(SDTokenizer): - """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" + ""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"" def __init__(self): super().__init__( @@ -1627,6 +1655,7 @@ class T5XXLTokenizer(SDTokenizer): max_length=99999999, min_length=77, ) +""" class T5LayerNorm(torch.nn.Module): diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 24591219..8f99d947 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -280,111 +280,6 @@ def sample_images(*args, **kwargs): return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) -class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy): - SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" - - def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: - super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) - self.vae = None - - def set_vae(self, vae: sd3_models.SDVAE): - self.vae = vae - - def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: - npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX) - if len(npz_file) == 0: - return None, None - w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") - return int(w), int(h) - - def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: - return ( - os.path.splitext(absolute_path)[0] - + f"_{image_size[0]:04d}x{image_size[1]:04d}" - + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX - ) - - def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - if not self.cache_to_disk: - return False - if not os.path.exists(npz_path): - return False - if self.skip_disk_cache_validity_check: - return True - - expected_latents_size = (bucket_reso[1] // 8, bucket_reso[0] // 8) # bucket_reso is (W, H) - - try: - npz = np.load(npz_path) - if npz["latents"].shape[1:3] != expected_latents_size: - return False - - if flip_aug: - if "latents_flipped" not in npz: - return False - if npz["latents_flipped"].shape[1:3] != expected_latents_size: - return False - - if alpha_mask: - if "alpha_mask" not in npz: - return False - if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]): - return False - else: - if "alpha_mask" in npz: - return False - except Exception as e: - logger.error(f"Error loading file: {npz_path}") - raise e - - return True - - def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): - img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching( - image_infos, alpha_mask, random_crop - ) - img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype) - - with torch.no_grad(): - latents_tensors = self.vae.encode(img_tensor).to("cpu") - if flip_aug: - img_tensor = torch.flip(img_tensor, dims=[3]) - with torch.no_grad(): - flipped_latents = self.vae.encode(img_tensor).to("cpu") - else: - flipped_latents = [None] * len(latents_tensors) - - # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks): - for i in range(len(image_infos)): - info = image_infos[i] - latents = latents_tensors[i] - flipped_latent = flipped_latents[i] - alpha_mask = alpha_masks[i] - original_size = original_sizes[i] - crop_ltrb = crop_ltrbs[i] - - if self.cache_to_disk: - kwargs = {} - if flipped_latent is not None: - kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy() - if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() - np.savez( - info.latents_npz, - latents=latents.float().cpu().numpy(), - original_size=np.array(original_size), - crop_ltrb=np.array(crop_ltrb), - **kwargs, - ) - else: - info.latents = latents - if flip_aug: - info.latents_flipped = flipped_latent - info.alpha_mask = alpha_mask - - if not train_util.HIGH_VRAM: - clean_memory_on_device(self.vae.device) - # region Diffusers diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 16f80c60..5849518f 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -384,6 +384,7 @@ def get_cond( dtype: Optional[torch.dtype] = None, ): l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt) + print(t5_tokens) return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index b74bea91..f009b577 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -327,7 +327,7 @@ def save_sd_model_on_epoch_end_or_stepwise( ) -def add_sdxl_training_arguments(parser: argparse.ArgumentParser): +def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True): parser.add_argument( "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" ) diff --git a/library/strategy_base.py b/library/strategy_base.py new file mode 100644 index 00000000..594cca5e --- /dev/null +++ b/library/strategy_base.py @@ -0,0 +1,328 @@ +# base class for platform strategies. this file defines the interface for strategies + +import os +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection + + +# TODO remove circular import by moving ImageInfo to a separate file +# from library.train_util import ImageInfo + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class TokenizeStrategy: + _strategy = None # strategy instance: actual strategy class + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TokenizeStrategy"]: + return cls._strategy + + def _load_tokenizer( + self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None + ) -> Any: + tokenizer = None + if tokenizer_cache_dir: + local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_")) + if os.path.exists(local_tokenizer_path): + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") + tokenizer = model_class.from_pretrained(local_tokenizer_path) # same for v1 and v2 + + if tokenizer is None: + tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder) + + if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") + tokenizer.save_pretrained(local_tokenizer_path) + + return tokenizer + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + raise NotImplementedError + + def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None) -> torch.Tensor: + """ + for SD1.5/2.0/SDXL + TODO support batch input + """ + if max_length is None: + max_length = tokenizer.model_max_length - 2 + + input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids + + if max_length > tokenizer.model_max_length: + input_ids = input_ids.squeeze(0) + iids_list = [] + if tokenizer.pad_token_id == tokenizer.eos_token_id: + # v1 + # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する + # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): # (1, 152, 75) + ids_chunk = ( + input_ids[0].unsqueeze(0), + input_ids[i : i + tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 or SDXL + # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): + ids_chunk = ( + input_ids[0].unsqueeze(0), # BOS + input_ids[i : i + tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) + + # 末尾が または の場合は、何もしなくてよい + # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) + if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id: + ids_chunk[-1] = tokenizer.eos_token_id + # 先頭が ... の場合は ... に変える + if ids_chunk[1] == tokenizer.pad_token_id: + ids_chunk[1] = tokenizer.eos_token_id + + iids_list.append(ids_chunk) + + input_ids = torch.stack(iids_list) # 3,77 + return input_ids + + +class TextEncodingStrategy: + _strategy = None # strategy instance: actual strategy class + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TextEncodingStrategy"]: + return cls._strategy + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Encode tokens into embeddings and outputs. + :param tokens: list of token tensors for each TextModel + :return: list of output embeddings for each architecture + """ + raise NotImplementedError + + +class TextEncoderOutputsCachingStrategy: + _strategy = None # strategy instance: actual strategy class + + def __init__( + self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + ) -> None: + self._cache_to_disk = cache_to_disk + self._batch_size = batch_size + self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + self._is_partial = is_partial + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]: + return cls._strategy + + @property + def cache_to_disk(self): + return self._cache_to_disk + + @property + def batch_size(self): + return self._batch_size + + @property + def is_partial(self): + return self._is_partial + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + raise NotImplementedError + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + raise NotImplementedError + + def is_disk_cached_outputs_expected(self, npz_path: str) -> bool: + raise NotImplementedError + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List + ): + raise NotImplementedError + + +class LatentsCachingStrategy: + # TODO commonize utillity functions to this class, such as npz handling etc. + + _strategy = None # strategy instance: actual strategy class + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + self._cache_to_disk = cache_to_disk + self._batch_size = batch_size + self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["LatentsCachingStrategy"]: + return cls._strategy + + @property + def cache_to_disk(self): + return self._cache_to_disk + + @property + def batch_size(self): + return self._batch_size + + def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + raise NotImplementedError + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + raise NotImplementedError + + def is_disk_cached_latents_expected( + self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + ) -> bool: + raise NotImplementedError + + def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + raise NotImplementedError + + def _defualt_is_disk_cached_latents_expected( + self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + ): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) + + try: + npz = np.load(npz_path) + if npz["latents"].shape[1:3] != expected_latents_size: + return False + + if flip_aug: + if "latents_flipped" not in npz: + return False + if npz["latents_flipped"].shape[1:3] != expected_latents_size: + return False + + if alpha_mask: + if "alpha_mask" not in npz: + return False + if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]): + return False + else: + if "alpha_mask" in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + # TODO remove circular dependency for ImageInfo + def _default_cache_batch_latents( + self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool + ): + """ + Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common. + """ + from library import train_util # import here to avoid circular import + + img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching( + image_infos, alpha_mask, random_crop + ) + img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype) + + with torch.no_grad(): + latents_tensors = encode_by_vae(img_tensor).to("cpu") + if flip_aug: + img_tensor = torch.flip(img_tensor, dims=[3]) + with torch.no_grad(): + flipped_latents = encode_by_vae(img_tensor).to("cpu") + else: + flipped_latents = [None] * len(latents_tensors) + + # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks): + for i in range(len(image_infos)): + info = image_infos[i] + latents = latents_tensors[i] + flipped_latent = flipped_latents[i] + alpha_mask = alpha_masks[i] + original_size = original_sizes[i] + crop_ltrb = crop_ltrbs[i] + + if self.cache_to_disk: + self.save_latents_to_disk(info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask) + else: + info.latents_original_size = original_size + info.latents_crop_ltrb = crop_ltrb + info.latents = latents + if flip_aug: + info.latents_flipped = flipped_latent + info.alpha_mask = alpha_mask + + def load_latents_from_disk( + self, npz_path: str + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + npz = np.load(npz_path) + if "latents" not in npz: + raise ValueError(f"error: npz is old format. please re-generate {npz_path}") + + latents = npz["latents"] + original_size = npz["original_size"].tolist() + crop_ltrb = npz["crop_ltrb"].tolist() + flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None + alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask + + def save_latents_to_disk( + self, npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None + ): + kwargs = {} + if flipped_latents_tensor is not None: + kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() + if alpha_mask is not None: + kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() + np.savez( + npz_path, + latents=latents_tensor.float().cpu().numpy(), + original_size=np.array(original_size), + crop_ltrb=np.array(crop_ltrb), + **kwargs, + ) diff --git a/library/strategy_sd.py b/library/strategy_sd.py new file mode 100644 index 00000000..10581614 --- /dev/null +++ b/library/strategy_sd.py @@ -0,0 +1,139 @@ +import glob +import os +from typing import Any, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTokenizer +from library import train_util +from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +TOKENIZER_ID = "openai/clip-vit-large-patch14" +V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ + + +class SdTokenizeStrategy(TokenizeStrategy): + def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None: + """ + max_length does not include and (None, 75, 150, 225) + """ + logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer") + if v2: + self.tokenizer = self._load_tokenizer( + CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir + ) + else: + self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + + if max_length is None: + self.max_length = self.tokenizer.model_max_length + else: + self.max_length = max_length + 2 + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] + + +class SdTextEncodingStrategy(TextEncodingStrategy): + def __init__(self, clip_skip: Optional[int] = None) -> None: + self.clip_skip = clip_skip + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + text_encoder = models[0] + tokens = tokens[0] + sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy + + # tokens: b,n,77 + b_size = tokens.size()[0] + max_token_length = tokens.size()[1] * tokens.size()[2] + model_max_length = sd_tokenize_strategy.tokenizer.model_max_length + tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77 + + if self.clip_skip is None: + encoder_hidden_states = text_encoder(tokens)[0] + else: + enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip] + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + + # bs*3, 77, 768 or 1024 + encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) + + if max_token_length != model_max_length: + v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id + if not v1: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, model_max_length): + chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token: + # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: ... の三連を ... へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, model_max_length): + states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) + + return [encoder_hidden_states] + + +class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): + # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix. + # and we keep the old npz for the backward compatibility. + + SD_OLD_LATENTS_NPZ_SUFFIX = ".npz" + SD_LATENTS_NPZ_SUFFIX = "_sd.npz" + SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz" + + def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + self.sd = sd + self.suffix = ( + SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX + ) + + def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + # does not include old npz + npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + self.suffix) + if len(npz_file) == 0: + return None, None + w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + # support old .npz + old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX + if os.path.exists(old_npz_file): + return old_npz_file + return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample() + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py new file mode 100644 index 00000000..42630ab2 --- /dev/null +++ b/library/strategy_sd3.py @@ -0,0 +1,229 @@ +import os +import glob +from typing import Any, List, Optional, Tuple, Union +import torch +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast + +from library import sd3_utils, train_util +from library import sd3_models +from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14" +CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" +T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl" + + +class Sd3TokenizeStrategy(TokenizeStrategy): + def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None: + self.t5xxl_max_length = t5xxl_max_length + self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.clip_g = self._load_tokenizer(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.clip_g.pad_token_id = 0 # use 0 as pad token for clip_g + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + + l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt") + + l_tokens = l_tokens["input_ids"] + g_tokens = g_tokens["input_ids"] + t5_tokens = t5_tokens["input_ids"] + + return [l_tokens, g_tokens, t5_tokens] + + +class Sd3TextEncodingStrategy(TextEncodingStrategy): + def __init__(self) -> None: + pass + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + clip_l, clip_g, t5xxl = models + + l_tokens, g_tokens, t5_tokens = tokens + if l_tokens is None: + assert g_tokens is None, "g_tokens must be None if l_tokens is None" + lg_out = None + else: + assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" + l_out, l_pooled = clip_l(l_tokens) + g_out, g_pooled = clip_g(g_tokens) + lg_out = torch.cat([l_out, g_out], dim=-1) + + if t5xxl is not None and t5_tokens is not None: + t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096] + else: + t5_out = None + + lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None + return [lg_out, t5_out, lg_pooled] + + def concat_encodings( + self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + if t5_out is None: + t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype) + return torch.cat([lg_out, t5_out], dim=-2), lg_pooled + + +class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz" + + def __init__( + self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, abs_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(self.get_outputs_npz_path(abs_path)): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(self.get_outputs_npz_path(abs_path)) + if "clip_l" not in npz or "clip_g" not in npz: + return False + if "clip_l_pool" not in npz or "clip_g_pool" not in npz: + return False + # t5xxl is optional + except Exception as e: + logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + lg_out = data["lg_out"] + lg_pooled = data["lg_pooled"] + t5_out = data["t5_out"] if "t5_out" in data else None + return [lg_out, t5_out, lg_pooled] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + captions = [info.caption for info in infos] + + clip_l_tokens, clip_g_tokens, t5xxl_tokens = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + lg_out, t5_out, lg_pooled = text_encoding_strategy.encode_tokens( + tokenize_strategy, models, [clip_l_tokens, clip_g_tokens, t5xxl_tokens] + ) + + if lg_out.dtype == torch.bfloat16: + lg_out = lg_out.float() + if lg_pooled.dtype == torch.bfloat16: + lg_pooled = lg_pooled.float() + if t5_out is not None and t5_out.dtype == torch.bfloat16: + t5_out = t5_out.float() + + lg_out = lg_out.cpu().numpy() + lg_pooled = lg_pooled.cpu().numpy() + if t5_out is not None: + t5_out = t5_out.cpu().numpy() + + for i, info in enumerate(infos): + lg_out_i = lg_out[i] + t5_out_i = t5_out[i] if t5_out is not None else None + lg_pooled_i = lg_pooled[i] + + if self.cache_to_disk: + kwargs = {} + if t5_out is not None: + kwargs["t5_out"] = t5_out_i + np.savez(info.text_encoder_outputs_npz, lg_out=lg_out_i, lg_pooled=lg_pooled_i, **kwargs) + else: + info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i) + + +class Sd3LatentsCachingStrategy(LatentsCachingStrategy): + SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX) + if len(npz_file) == 0: + return None, None + w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) + + +if __name__ == "__main__": + # test code for Sd3TokenizeStrategy + # tokenizer = sd3_models.SD3Tokenizer() + strategy = Sd3TokenizeStrategy(256) + text = "hello world" + + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + # print(l_tokens.shape) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + texts = ["hello world", "the quick brown fox jumps over the lazy dog"] + l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens_2 = strategy.t5xxl( + texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + print(l_tokens_2) + print(g_tokens_2) + print(t5_tokens_2) + + # compare + print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0])) + print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0])) + print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0])) + + text = ",".join(["hello world! this is long text"] * 50) + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + print(f"model max length l: {strategy.clip_l.model_max_length}") + print(f"model max length g: {strategy.clip_g.model_max_length}") + print(f"model max length t5: {strategy.t5xxl.model_max_length}") diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py new file mode 100644 index 00000000..a4513336 --- /dev/null +++ b/library/strategy_sdxl.py @@ -0,0 +1,247 @@ +import os +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection +from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy + + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +TOKENIZER1_PATH = "openai/clip-vit-large-patch14" +TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + + +class SdxlTokenizeStrategy(TokenizeStrategy): + def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None: + self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir) + self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir) + self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2 + + if max_length is None: + self.max_length = self.tokenizer1.model_max_length + else: + self.max_length = max_length + 2 + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + return ( + torch.stack([self._get_input_ids(self.tokenizer1, t, self.max_length) for t in text], dim=0), + torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0), + ) + + +class SdxlTextEncodingStrategy(TextEncodingStrategy): + def __init__(self) -> None: + pass + + def _pool_workaround( + self, text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int + ): + r""" + workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output + instead of the hidden states for the EOS token + If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output + + Original code from CLIP's pooling function: + + \# text_embeds.shape = [batch_size, sequence_length, transformer.width] + \# take features from the eot embedding (eot_token is the highest number in each sequence) + \# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + ] + """ + + # input_ids: b*n,77 + # find index for EOS token + + # Following code is not working if one of the input_ids has multiple EOS tokens (very odd case) + # eos_token_index = torch.where(input_ids == eos_token_id)[1] + # eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # Create a mask where the EOS tokens are + eos_token_mask = (input_ids == eos_token_id).int() + + # Use argmax to find the last index of the EOS token for each element in the batch + eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine + eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # get hidden states for EOS token + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index + ] + + # apply projection: projection may be of different dtype than last_hidden_state + pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype)) + pooled_output = pooled_output.to(last_hidden_state.dtype) + + return pooled_output + + def _get_hidden_states_sdxl( + self, + input_ids1: torch.Tensor, + input_ids2: torch.Tensor, + tokenizer1: CLIPTokenizer, + tokenizer2: CLIPTokenizer, + text_encoder1: Union[CLIPTextModel, torch.nn.Module], + text_encoder2: Union[CLIPTextModelWithProjection, torch.nn.Module], + unwrapped_text_encoder2: Optional[CLIPTextModelWithProjection] = None, + ): + # input_ids: b,n,77 -> b*n, 77 + b_size = input_ids1.size()[0] + max_token_length = input_ids1.size()[1] * input_ids1.size()[2] + input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77 + input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77 + input_ids1 = input_ids1.to(text_encoder1.device) + input_ids2 = input_ids2.to(text_encoder2.device) + + # text_encoder1 + enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True) + hidden_states1 = enc_out["hidden_states"][11] + + # text_encoder2 + enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True) + hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer + + # pool2 = enc_out["text_embeds"] + unwrapped_text_encoder2 = unwrapped_text_encoder2 or text_encoder2 + pool2 = self._pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id) + + # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280 + n_size = 1 if max_token_length is None else max_token_length // 75 + hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1])) + hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1])) + + if max_token_length is not None: + # bs*3, 77, 768 or 1024 + # encoder1: ... の三連を ... へ戻す + states_list = [hidden_states1[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, tokenizer1.model_max_length): + states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # の後から の前まで + states_list.append(hidden_states1[:, -1].unsqueeze(1)) # + hidden_states1 = torch.cat(states_list, dim=1) + + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [hidden_states2[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, tokenizer2.model_max_length): + chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # の後から 最後の前まで + # this causes an error: + # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation + # if i > 1: + # for j in range(len(chunk)): # batch_size + # if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり ...のパターン + # chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(hidden_states2[:, -1].unsqueeze(1)) # のどちらか + hidden_states2 = torch.cat(states_list, dim=1) + + # pool はnの最初のものを使う + pool2 = pool2[::n_size] + + return hidden_states1, hidden_states2, pool2 + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Args: + tokenize_strategy: TokenizeStrategy + models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)] + tokens: List of tokens, for text_encoder1 and text_encoder2 + """ + if len(models) == 2: + text_encoder1, text_encoder2 = models + unwrapped_text_encoder2 = None + else: + text_encoder1, text_encoder2, unwrapped_text_encoder2 = models + tokens1, tokens2 = tokens + sdxl_tokenize_strategy = tokenize_strategy # type: SdxlTokenizeStrategy + tokenizer1, tokenizer2 = sdxl_tokenize_strategy.tokenizer1, sdxl_tokenize_strategy.tokenizer2 + + hidden_states1, hidden_states2, pool2 = self._get_hidden_states_sdxl( + tokens1, tokens2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, unwrapped_text_encoder2 + ) + return [hidden_states1, hidden_states2, pool2] + + +class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz" + + def __init__( + self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, abs_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(self.get_outputs_npz_path(abs_path)): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(self.get_outputs_npz_path(abs_path)) + if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + hidden_state1 = data["hidden_state1"] + hidden_state2 = data["hidden_state2"] + pool2 = data["pool2"] + return [hidden_state1, hidden_state2, pool2] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy + captions = [info.caption for info in infos] + + tokens1, tokens2 = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, [tokens1, tokens2] + ) + if hidden_state1.dtype == torch.bfloat16: + hidden_state1 = hidden_state1.float() + if hidden_state2.dtype == torch.bfloat16: + hidden_state2 = hidden_state2.float() + if pool2.dtype == torch.bfloat16: + pool2 = pool2.float() + + hidden_state1 = hidden_state1.cpu().numpy() + hidden_state2 = hidden_state2.cpu().numpy() + pool2 = pool2.cpu().numpy() + + for i, info in enumerate(infos): + hidden_state1_i = hidden_state1[i] + hidden_state2_i = hidden_state2[i] + pool2_i = pool2[i] + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + hidden_state1=hidden_state1_i, + hidden_state2=hidden_state2_i, + pool2=pool2_i, + ) + else: + info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i] diff --git a/library/train_util.py b/library/train_util.py index 7af0070e..a747e047 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -12,6 +12,7 @@ import re import shutil import time from typing import ( + Any, Dict, List, NamedTuple, @@ -34,6 +35,7 @@ from tqdm import tqdm import torch from library.device_utils import init_ipex, clean_memory_on_device +from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy init_ipex() @@ -81,10 +83,6 @@ logger = logging.getLogger(__name__) # from library.hypernetwork import replace_attentions_for_hypernetwork from library.original_unet import UNet2DConditionModel -# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う -TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ - HIGH_VRAM = False # checkpointファイル名 @@ -148,18 +146,24 @@ class ImageInfo: self.image_size: Tuple[int, int] = None self.resized_size: Tuple[int, int] = None self.bucket_reso: Tuple[int, int] = None - self.latents: torch.Tensor = None - self.latents_flipped: torch.Tensor = None - self.latents_npz: str = None - self.latents_original_size: Tuple[int, int] = None # original image size, not latents size - self.latents_crop_ltrb: Tuple[int, int] = None # crop left top right bottom in original pixel size, not latents size - self.cond_img_path: str = None + self.latents: Optional[torch.Tensor] = None + self.latents_flipped: Optional[torch.Tensor] = None + self.latents_npz: Optional[str] = None # set in cache_latents + self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size + self.latents_crop_ltrb: Optional[Tuple[int, int]] = ( + None # crop left top right bottom in original pixel size, not latents size + ) + self.cond_img_path: Optional[str] = None self.image: Optional[Image.Image] = None # optional, original PIL Image - # SDXL, optional - self.text_encoder_outputs_npz: Optional[str] = None + self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs + + # new + self.text_encoder_outputs: Optional[List[torch.Tensor]] = None + # old self.text_encoder_outputs1: Optional[torch.Tensor] = None self.text_encoder_outputs2: Optional[torch.Tensor] = None self.text_encoder_pool2: Optional[torch.Tensor] = None + self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime @@ -359,47 +363,6 @@ class AugHelper: return self.color_aug if use_color_aug else None -class LatentsCachingStrategy: - _strategy = None # strategy instance: actual strategy class - - def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: - self._cache_to_disk = cache_to_disk - self._batch_size = batch_size - self.skip_disk_cache_validity_check = skip_disk_cache_validity_check - - @classmethod - def set_strategy(cls, strategy): - if cls._strategy is not None: - raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") - cls._strategy = strategy - - @classmethod - def get_strategy(cls) -> Optional["LatentsCachingStrategy"]: - return cls._strategy - - @property - def cache_to_disk(self): - return self._cache_to_disk - - @property - def batch_size(self): - return self._batch_size - - def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: - raise NotImplementedError - - def get_latents_npz_path(self, absolute_path: str, bucket_reso: Tuple[int, int]) -> str: - raise NotImplementedError - - def is_disk_cached_latents_expected( - self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool - ) -> bool: - raise NotImplementedError - - def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): - raise NotImplementedError - - class BaseSubset: def __init__( self, @@ -639,17 +602,12 @@ class ControlNetSubset(BaseSubset): class BaseDataset(torch.utils.data.Dataset): def __init__( self, - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]], - max_token_length: int, resolution: Optional[Tuple[int, int]], network_multiplier: float, debug_dataset: bool, ) -> None: super().__init__() - self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer] - - self.max_token_length = max_token_length # width/height is used when enable_bucket==False self.width, self.height = (None, None) if resolution is None else resolution self.network_multiplier = network_multiplier @@ -670,8 +628,6 @@ class BaseDataset(torch.utils.data.Dataset): self.bucket_no_upscale = None self.bucket_info = None # for metadata - self.tokenizer_max_length = self.tokenizers[0].model_max_length if max_token_length is None else max_token_length + 2 - self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ self.current_step: int = 0 @@ -690,6 +646,15 @@ class BaseDataset(torch.utils.data.Dataset): # caching self.caching_mode = None # None, 'latents', 'text' + + self.tokenize_strategy = None + self.text_encoder_output_caching_strategy = None + self.latents_caching_strategy = None + + def set_current_strategies(self): + self.tokenize_strategy = TokenizeStrategy.get_strategy() + self.text_encoder_output_caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy() + self.latents_caching_strategy = LatentsCachingStrategy.get_strategy() def set_seed(self, seed): self.seed = seed @@ -979,22 +944,6 @@ class BaseDataset(torch.utils.data.Dataset): for batch_index in range(batch_count): self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index)) - # ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す - #  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる - # - # # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは - # # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう - # # そのためバッチサイズを画像種類までに制限する - # # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない? - # # TO DO 正則化画像をepochまたがりで利用する仕組み - # num_of_image_types = len(set(bucket)) - # bucket_batch_size = min(self.batch_size, num_of_image_types) - # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) - # # logger.info(bucket_index, num_of_image_types, bucket_batch_size, batch_count) - # for batch_index in range(batch_count): - # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) - # ↑ここまで - self.shuffle_buckets() self._length = len(self.buckets_indices) @@ -1027,12 +976,13 @@ class BaseDataset(torch.utils.data.Dataset): ] ) - def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCachingStrategy): + def new_cache_latents(self, model: Any, is_main_process: bool): r""" a brand new method to cache latents. This method caches latents with caching strategy. normal cache_latents method is used by default, but this method is used when caching strategy is specified. """ logger.info("caching latents with caching strategy.") + caching_strategy = LatentsCachingStrategy.get_strategy() image_infos = list(self.image_data.values()) # sort by resolution @@ -1088,7 +1038,7 @@ class BaseDataset(torch.utils.data.Dataset): logger.info("caching latents...") for batch in tqdm(batches, smoothing=1, total=len(batches)): # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) - caching_strategy.cache_batch_latents(batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + caching_strategy.cache_batch_latents(model, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと @@ -1145,6 +1095,56 @@ class BaseDataset(torch.utils.data.Dataset): for batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + r""" + a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy. + """ + tokenize_strategy = TokenizeStrategy.get_strategy() + text_encoding_strategy = TextEncodingStrategy.get_strategy() + caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy() + batch_size = caching_strategy.batch_size or self.batch_size + + # if cache to disk, don't cache TE outputs in non-main process + if caching_strategy.cache_to_disk and not is_main_process: + return + + logger.info("caching Text Encoder outputs with caching strategy.") + image_infos = list(self.image_data.values()) + + # split by resolution + batches = [] + batch = [] + logger.info("checking cache validity...") + for info in tqdm(image_infos): + te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) + + # check disk cache exists and size of latents + if caching_strategy.cache_to_disk: + info.text_encoder_outputs_npz = te_out_npz + cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) + if cache_available: # do not add to batch + continue + + batch.append(info) + + # if number of data in batch is enough, flush the batch + if len(batch) >= batch_size: + batches.append(batch) + batch = [] + + if len(batch) > 0: + batches.append(batch) + + if len(batches) == 0: + logger.info("no Text Encoder outputs to cache") + return + + # iterate batches + logger.info("caching Text Encoder outputs...") + for batch in tqdm(batches, smoothing=1, total=len(batches)): + # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + caching_strategy.cache_batch_outputs(tokenize_strategy, models, text_encoding_strategy, batch) + # if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype # this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset # to support SD1/2, it needs a flag for v2, but it is postponed @@ -1188,6 +1188,8 @@ class BaseDataset(torch.utils.data.Dataset): # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching text encoder outputs.") + tokenize_strategy = TokenizeStrategy.get_strategy() + if batch_size is None: batch_size = self.batch_size @@ -1229,7 +1231,7 @@ class BaseDataset(torch.utils.data.Dataset): input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) batch.append((info, input_ids1, input_ids2)) else: - l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption) + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(info.caption) batch.append((info, l_tokens, g_tokens, t5_tokens)) if len(batch) >= batch_size: @@ -1347,7 +1349,6 @@ class BaseDataset(torch.utils.data.Dataset): loss_weights = [] captions = [] input_ids_list = [] - input_ids2_list = [] latents_list = [] alpha_mask_list = [] images = [] @@ -1355,16 +1356,14 @@ class BaseDataset(torch.utils.data.Dataset): crop_top_lefts = [] target_sizes_hw = [] flippeds = [] # 変数名が微妙 - text_encoder_outputs1_list = [] - text_encoder_outputs2_list = [] - text_encoder_pool2_list = [] + text_encoder_outputs_list = [] for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] - loss_weights.append( - self.prior_loss_weight if image_info.is_reg else 1.0 - ) # in case of fine tuning, is_reg is always False + + # in case of fine tuning, is_reg is always False + loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance @@ -1381,7 +1380,9 @@ class BaseDataset(torch.utils.data.Dataset): image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz) + latents, original_size, crop_ltrb, flipped_latents, alpha_mask = ( + self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz) + ) if flipped: latents = flipped_latents alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem @@ -1470,75 +1471,67 @@ class BaseDataset(torch.utils.data.Dataset): # captionとtext encoder outputを処理する caption = image_info.caption # default - if image_info.text_encoder_outputs1 is not None: - text_encoder_outputs1_list.append(image_info.text_encoder_outputs1) - text_encoder_outputs2_list.append(image_info.text_encoder_outputs2) - text_encoder_pool2_list.append(image_info.text_encoder_pool2) - captions.append(caption) + + tokenization_required = ( + self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial + ) + text_encoder_outputs = None + input_ids = None + + if image_info.text_encoder_outputs is not None: + # cached + text_encoder_outputs = image_info.text_encoder_outputs elif image_info.text_encoder_outputs_npz is not None: - text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk( + # on disk + text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( image_info.text_encoder_outputs_npz ) - text_encoder_outputs1_list.append(text_encoder_outputs1) - text_encoder_outputs2_list.append(text_encoder_outputs2) - text_encoder_pool2_list.append(text_encoder_pool2) - captions.append(caption) else: + tokenization_required = True + text_encoder_outputs_list.append(text_encoder_outputs) + + if tokenization_required: caption = self.process_caption(subset, image_info.caption) - if self.XTI_layers: - caption_layer = [] - for layer in self.XTI_layers: - token_strings_from = " ".join(self.token_strings) - token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) - caption_ = caption.replace(token_strings_from, token_strings_to) - caption_layer.append(caption_) - captions.append(caption_layer) - else: - captions.append(caption) + input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension + # if self.XTI_layers: + # caption_layer = [] + # for layer in self.XTI_layers: + # token_strings_from = " ".join(self.token_strings) + # token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) + # caption_ = caption.replace(token_strings_from, token_strings_to) + # caption_layer.append(caption_) + # captions.append(caption_layer) + # else: + # captions.append(caption) - if not self.token_padding_disabled: # this option might be omitted in future - # TODO get_input_ids must support SD3 - if self.XTI_layers: - token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) - else: - token_caption = self.get_input_ids(caption, self.tokenizers[0]) - input_ids_list.append(token_caption) + # if not self.token_padding_disabled: # this option might be omitted in future + # # TODO get_input_ids must support SD3 + # if self.XTI_layers: + # token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) + # else: + # token_caption = self.get_input_ids(caption, self.tokenizers[0]) + # input_ids_list.append(token_caption) - if len(self.tokenizers) > 1: - if self.XTI_layers: - token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) - else: - token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) - input_ids2_list.append(token_caption2) + # if len(self.tokenizers) > 1: + # if self.XTI_layers: + # token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) + # else: + # token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) + # input_ids2_list.append(token_caption2) + + input_ids_list.append(input_ids) + captions.append(caption) + + def none_or_stack_elements(tensors_list, converter): + # [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)] + if len(tensors_list) == 0 or tensors_list[0] == None or len(tensors_list[0]) == 0 or tensors_list[0][0] is None: + return None + return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))] example = {} example["loss_weights"] = torch.FloatTensor(loss_weights) - - if len(text_encoder_outputs1_list) == 0: - if self.token_padding_disabled: - # padding=True means pad in the batch - example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids - if len(self.tokenizers) > 1: - example["input_ids2"] = self.tokenizer[1]( - captions, padding=True, truncation=True, return_tensors="pt" - ).input_ids - else: - example["input_ids2"] = None - else: - example["input_ids"] = torch.stack(input_ids_list) - example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None - example["text_encoder_outputs1_list"] = None - example["text_encoder_outputs2_list"] = None - example["text_encoder_pool2_list"] = None - else: - example["input_ids"] = None - example["input_ids2"] = None - # # for assertion - # example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions]) - # example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions]) - example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list) - example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list) - example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list) + example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor) + example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x) # if one of alpha_masks is not None, we need to replace None with ones none_or_not = [x is None for x in alpha_mask_list] @@ -1652,8 +1645,6 @@ class DreamBoothDataset(BaseDataset): self, subsets: Sequence[DreamBoothSubset], batch_size: int, - tokenizer, - max_token_length, resolution, network_multiplier: float, enable_bucket: bool, @@ -1664,7 +1655,7 @@ class DreamBoothDataset(BaseDataset): prior_loss_weight: float, debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -1750,10 +1741,10 @@ class DreamBoothDataset(BaseDataset): # new caching: get image size from cache files strategy = LatentsCachingStrategy.get_strategy() if strategy is not None: - logger.info("get image size from cache files") + logger.info("get image size from name of cache files") size_set_count = 0 for i, img_path in enumerate(tqdm(img_paths)): - w, h = strategy.get_image_size_from_image_absolute_path(img_path) + w, h = strategy.get_image_size_from_disk_cache_path(img_path) if w is not None and h is not None: sizes[i] = [w, h] size_set_count += 1 @@ -1886,8 +1877,6 @@ class FineTuningDataset(BaseDataset): self, subsets: Sequence[FineTuningSubset], batch_size: int, - tokenizer, - max_token_length, resolution, network_multiplier: float, enable_bucket: bool, @@ -1897,7 +1886,7 @@ class FineTuningDataset(BaseDataset): bucket_no_upscale: bool, debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset) self.batch_size = batch_size @@ -2111,8 +2100,6 @@ class ControlNetDataset(BaseDataset): self, subsets: Sequence[ControlNetSubset], batch_size: int, - tokenizer, - max_token_length, resolution, network_multiplier: float, enable_bucket: bool, @@ -2122,7 +2109,7 @@ class ControlNetDataset(BaseDataset): bucket_no_upscale: bool, debug_dataset: float, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset) db_subsets = [] for subset in subsets: @@ -2160,8 +2147,6 @@ class ControlNetDataset(BaseDataset): self.dreambooth_dataset_delegate = DreamBoothDataset( db_subsets, batch_size, - tokenizer, - max_token_length, resolution, network_multiplier, enable_bucket, @@ -2221,6 +2206,9 @@ class ControlNetDataset(BaseDataset): self.conditioning_image_transforms = IMAGE_TRANSFORMS + def set_current_strategies(self): + return self.dreambooth_dataset_delegate.set_current_strategies() + def make_buckets(self): self.dreambooth_dataset_delegate.make_buckets() self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager @@ -2229,6 +2217,12 @@ class ControlNetDataset(BaseDataset): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + def new_cache_latents(self, model: Any, is_main_process: bool): + return self.dreambooth_dataset_delegate.new_cache_latents(model, is_main_process) + + def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process) + def __len__(self): return self.dreambooth_dataset_delegate.__len__() @@ -2314,6 +2308,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset): # for dataset in self.datasets: # dataset.make_buckets() + def set_text_encoder_output_caching_strategy(self, strategy: TextEncoderOutputsCachingStrategy): + """ + DataLoader is run in multiple processes, so we need to set the strategy manually. + """ + for dataset in self.datasets: + dataset.set_text_encoder_output_caching_strategy(strategy) + def enable_XTI(self, *args, **kwargs): for dataset in self.datasets: dataset.enable_XTI(*args, **kwargs) @@ -2323,10 +2324,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset): logger.info(f"[Dataset {i}]") dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) - def new_cache_latents(self, is_main_process: bool, strategy: LatentsCachingStrategy): + def new_cache_latents(self, model: Any, is_main_process: bool): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.new_cache_latents(is_main_process, strategy) + dataset.new_cache_latents(model, is_main_process) def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True @@ -2344,6 +2345,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset): tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size ) + def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.new_cache_text_encoder_outputs(models, is_main_process) + def set_caching_mode(self, caching_mode): for dataset in self.datasets: dataset.set_caching_mode(caching_mode) @@ -2358,6 +2364,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset): def is_text_encoder_output_cacheable(self) -> bool: return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets]) + def set_current_strategies(self): + for dataset in self.datasets: + dataset.set_current_strategies() + def set_current_epoch(self, epoch): for dataset in self.datasets: dataset.set_current_epoch(epoch) @@ -2411,34 +2421,34 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) # TODO update to use CachingStrategy -def load_latents_from_disk( - npz_path, -) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: - npz = np.load(npz_path) - if "latents" not in npz: - raise ValueError(f"error: npz is old format. please re-generate {npz_path}") +# def load_latents_from_disk( +# npz_path, +# ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: +# npz = np.load(npz_path) +# if "latents" not in npz: +# raise ValueError(f"error: npz is old format. please re-generate {npz_path}") - latents = npz["latents"] - original_size = npz["original_size"].tolist() - crop_ltrb = npz["crop_ltrb"].tolist() - flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None - alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None - return latents, original_size, crop_ltrb, flipped_latents, alpha_mask +# latents = npz["latents"] +# original_size = npz["original_size"].tolist() +# crop_ltrb = npz["crop_ltrb"].tolist() +# flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None +# alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None +# return latents, original_size, crop_ltrb, flipped_latents, alpha_mask -def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None): - kwargs = {} - if flipped_latents_tensor is not None: - kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() - if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() - np.savez( - npz_path, - latents=latents_tensor.float().cpu().numpy(), - original_size=np.array(original_size), - crop_ltrb=np.array(crop_ltrb), - **kwargs, - ) +# def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None): +# kwargs = {} +# if flipped_latents_tensor is not None: +# kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() +# if alpha_mask is not None: +# kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() +# np.savez( +# npz_path, +# latents=latents_tensor.float().cpu().numpy(), +# original_size=np.array(original_size), +# crop_ltrb=np.array(crop_ltrb), +# **kwargs, +# ) def debug_dataset(train_dataset, show_input_ids=False): @@ -2465,12 +2475,12 @@ def debug_dataset(train_dataset, show_input_ids=False): example = train_dataset[idx] if example["latents"] is not None: logger.info(f"sample has latents from npz file: {example['latents'].size()}") - for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate( + for j, (ik, cap, lw, orgsz, crptl, trgsz, flpdz) in enumerate( zip( example["image_keys"], example["captions"], example["loss_weights"], - example["input_ids"], + # example["input_ids"], example["original_sizes_hw"], example["crop_top_lefts"], example["target_sizes_hw"], @@ -2483,10 +2493,10 @@ def debug_dataset(train_dataset, show_input_ids=False): if "network_multipliers" in example: print(f"network multiplier: {example['network_multipliers'][j]}") - if show_input_ids: - logger.info(f"input ids: {iid}") - if "input_ids2" in example: - logger.info(f"input ids2: {example['input_ids2'][j]}") + # if show_input_ids: + # logger.info(f"input ids: {iid}") + # if "input_ids2" in example: + # logger.info(f"input ids2: {example['input_ids2'][j]}") if example["images"] is not None: im = example["images"][j] logger.info(f"image size: {im.size()}") @@ -2555,8 +2565,8 @@ def glob_images_pathlib(dir_path, recursive): class MinimalDataset(BaseDataset): - def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False): - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + def __init__(self, resolution, network_multiplier, debug_dataset=False): + super().__init__(resolution, network_multiplier, debug_dataset) self.num_train_images = 0 # update in subclass self.num_reg_images = 0 # update in subclass @@ -2773,14 +2783,15 @@ def cache_batch_latents( raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") if cache_to_disk: - save_latents_to_disk( - info.latents_npz, - latent, - info.latents_original_size, - info.latents_crop_ltrb, - flipped_latent, - alpha_mask, - ) + # save_latents_to_disk( + # info.latents_npz, + # latent, + # info.latents_original_size, + # info.latents_crop_ltrb, + # flipped_latent, + # alpha_mask, + # ) + pass else: info.latents = latent if flip_aug: @@ -4662,33 +4673,6 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): ) -def load_tokenizer(args: argparse.Namespace): - logger.info("prepare tokenizer") - original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH - - tokenizer: CLIPTokenizer = None - if args.tokenizer_cache_dir: - local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) - if os.path.exists(local_tokenizer_path): - logger.info(f"load tokenizer from cache: {local_tokenizer_path}") - tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2 - - if tokenizer is None: - if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer") - else: - tokenizer = CLIPTokenizer.from_pretrained(original_path) - - if hasattr(args, "max_token_length") and args.max_token_length is not None: - logger.info(f"update token length: {args.max_token_length}") - - if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") - tokenizer.save_pretrained(local_tokenizer_path) - - return tokenizer - - def prepare_accelerator(args: argparse.Namespace): """ this function also prepares deepspeed plugin @@ -5550,6 +5534,7 @@ def sample_images_common( ): """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した + TODO Use strategies here """ if steps == 0: diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index ffa0d46d..e9e61af1 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -24,7 +24,7 @@ import logging logger = logging.getLogger(__name__) -from library import sd3_models, sd3_utils +from library import sd3_models, sd3_utils, strategy_sd3 def get_noise(seed, latent): @@ -145,6 +145,7 @@ if __name__ == "__main__": parser.add_argument("--clip_g", type=str, required=False) parser.add_argument("--clip_l", type=str, required=False) parser.add_argument("--t5xxl", type=str, required=False) + parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77") parser.add_argument("--prompt", type=str, default="A photo of a cat") # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders parser.add_argument("--negative_prompt", type=str, default="") @@ -247,7 +248,7 @@ if __name__ == "__main__": # load tokenizers logger.info("Loading tokenizers...") - tokenizer = sd3_models.SD3Tokenizer(use_t5xxl) # combined tokenizer + tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length) # load models # logger.info("Create MMDiT from SD3 checkpoint...") @@ -320,12 +321,19 @@ if __name__ == "__main__": # prepare embeddings logger.info("Encoding prompts...") - # embeds, pooled_embed - lg_out, t5_out, pooled = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl) - cond = torch.cat([lg_out, t5_out], dim=-2), pooled + encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() - lg_out, t5_out, pooled = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl) - neg_cond = torch.cat([lg_out, t5_out], dim=-2), pooled + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.prompt) + lg_out, t5_out, pooled = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens] + ) + cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.negative_prompt) + lg_out, t5_out, pooled = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens] + ) + neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) # generate image logger.info("Generating image...") diff --git a/sd3_train.py b/sd3_train.py index f34e4712..617e3027 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -17,7 +17,7 @@ init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils +from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3 from library.sdxl_train_util import match_mixed_precision # , sdxl_model_util @@ -69,10 +69,22 @@ def train(args): # not args.train_text_encoder # ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" - # training without text encoder cache is not supported - assert ( - args.cache_text_encoder_outputs - ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません" + # # training without text encoder cache is not supported: because T5XXL must be cached + # assert ( + # args.cache_text_encoder_outputs + # ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません" + + assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), ( + "when training text encoder, text encoder outputs must not be cached (except for T5XXL)" + + " / text encoderの学習時はtext encoderの出力はキャッシュできません(t5xxlのみキャッシュすることは可能です)" + ) + + if args.use_t5xxl_cache_only and not args.cache_text_encoder_outputs: + logger.warning( + "use_t5xxl_cache_only is enabled, so cache_text_encoder_outputs is automatically enabled." + + " / use_t5xxl_cache_onlyが有効なため、cache_text_encoder_outputsも自動的に有効になります" + ) + args.cache_text_encoder_outputs = True # if args.block_lr: # block_lrs = [float(lr) for lr in args.block_lr.split(",")] @@ -88,17 +100,17 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - # load tokenizer - sd3_tokenizer = sd3_models.SD3Tokenizer() - - # prepare caching strategy - if args.new_caching: - latents_caching_strategy = sd3_train_utils.Sd3LatentsCachingStrategy( + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check ) - else: - latents_caching_strategy = None - train_util.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # load tokenizer and prepare tokenize strategy + sd3_tokenizer = sd3_models.SD3Tokenizer(t5xxl_max_length=args.t5xxl_max_token_length) + sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy) # データセットを準備する if args.dataset_class is None: @@ -153,6 +165,16 @@ def train(args): train_dataset_group.verify_bucket_reso_steps(8) # TODO これでいいか確認 if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + False, + False, + ) + ) + train_dataset_group.set_current_strategies() train_util.debug_dataset(train_dataset_group, True) return if len(train_dataset_group) == 0: @@ -215,19 +237,8 @@ def train(args): vae.requires_grad_(False) vae.eval() - if not args.new_caching: - vae_wrapper = sd3_models.VAEWrapper(vae) # make SD/SDXL compatible - with torch.no_grad(): - train_dataset_group.cache_latents( - vae_wrapper, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - file_suffix="_sd3.npz", - ) - else: - latents_caching_strategy.set_vae(vae) - train_dataset_group.new_cache_latents(accelerator.is_main_process, latents_caching_strategy) + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") # if no sampling, vae can be deleted clean_memory_on_device(accelerator.device) @@ -246,60 +257,70 @@ def train(args): t5xxl = sd3_train_utils.load_target_model("t5xxl", args, sd3_state_dict, accelerator, attn_mode, t5xxl_dtype, device_to_load) # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) + # should be deleted after caching text encoder outputs when not training text encoder + # this strategy should not be used other than this process + text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # 学習を準備する:モデルを適切な状態にする train_clip_l = False train_clip_g = False train_t5xxl = False - # if args.train_text_encoder: - # # TODO each option for two text encoders? - # accelerator.print("enable text encoder training") - # if args.gradient_checkpointing: - # text_encoder1.gradient_checkpointing_enable() - # text_encoder2.gradient_checkpointing_enable() - # lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train - # lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train - # train_clip_l = lr_te1 != 0 - # train_clip_g = lr_te2 != 0 + if args.train_text_encoder: + accelerator.print("enable text encoder training") + if args.gradient_checkpointing: + clip_l.gradient_checkpointing_enable() + clip_g.gradient_checkpointing_enable() + lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train + lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train + train_clip_l = lr_te1 != 0 + train_clip_g = lr_te2 != 0 + + if not train_clip_l: + clip_l.to(weight_dtype) + if not train_clip_g: + clip_g.to(weight_dtype) + clip_l.requires_grad_(train_clip_l) + clip_g.requires_grad_(train_clip_g) + clip_l.train(train_clip_l) + clip_g.train(train_clip_g) + else: + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + clip_l.requires_grad_(False) + clip_g.requires_grad_(False) + clip_l.eval() + clip_g.eval() - # # caching one text encoder output is not supported - # if not train_clip_l: - # text_encoder1.to(weight_dtype) - # if not train_clip_g: - # text_encoder2.to(weight_dtype) - # text_encoder1.requires_grad_(train_clip_l) - # text_encoder2.requires_grad_(train_clip_g) - # text_encoder1.train(train_clip_l) - # text_encoder2.train(train_clip_g) - # else: - clip_l.to(weight_dtype) - clip_g.to(weight_dtype) - clip_l.requires_grad_(False) - clip_g.requires_grad_(False) - clip_l.eval() - clip_g.eval() if t5xxl is not None: t5xxl.to(t5xxl_dtype) t5xxl.requires_grad_(False) t5xxl.eval() - # TextEncoderの出力をキャッシュする + # cache text encoder outputs if args.cache_text_encoder_outputs: - # Text Encodes are eval and no grad + # Text Encodes are eval and no grad here + clip_l.to(accelerator.device) + clip_g.to(accelerator.device) + if t5xxl is not None: + t5xxl.to(t5xxl_device) - with torch.no_grad(), accelerator.autocast(): - train_dataset_group.cache_text_encoder_outputs_sd3( - sd3_tokenizer, - (clip_l, clip_g, t5xxl), - (accelerator.device, accelerator.device, t5xxl_device), - None, - (None, None, None), - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - args.text_encoder_batch_size, - ) + text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + False, + train_clip_g or train_clip_l or args.use_t5xxl_cache_only, + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) - # TODO we can delete text encoders after caching + clip_l.to(accelerator.device, dtype=weight_dtype) + clip_g.to(accelerator.device, dtype=weight_dtype) + if t5xxl is not None: + t5xxl.to(t5xxl_device, dtype=t5xxl_dtype) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process) accelerator.wait_for_everyone() # load MMDIT @@ -332,11 +353,11 @@ def train(args): # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs)) # if train_clip_l: - # training_models.append(text_encoder1) - # params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) + # training_models.append(clip_l) + # params_to_optimize.append({"params": list(clip_l.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) # if train_clip_g: - # training_models.append(text_encoder2) - # params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) + # training_models.append(clip_g) + # params_to_optimize.append({"params": list(clip_g.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) # calculate number of trainable parameters n_params = 0 @@ -344,7 +365,7 @@ def train(args): for p in group["params"]: n_params += p.numel() - accelerator.print(f"train mmdit: {train_mmdit}") # , text_encoder1: {train_clip_l}, text_encoder2: {train_clip_g}") + accelerator.print(f"train mmdit: {train_mmdit}") # , clip_l: {train_clip_l}, clip_g: {train_clip_g}") accelerator.print(f"number of models: {len(training_models)}") accelerator.print(f"number of trainable parameters: {n_params}") @@ -398,7 +419,11 @@ def train(args): else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -455,8 +480,8 @@ def train(args): # TODO check if this is necessary. SD3 uses pool for clip_l and clip_g # # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer # if train_clip_l: - # text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) - # text_encoder1.text_model.final_layer_norm.requires_grad_(False) + # clip_l.text_model.encoder.layers[-1].requires_grad_(False) + # clip_l.text_model.final_layer_norm.requires_grad_(False) # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する if args.cache_text_encoder_outputs: @@ -484,9 +509,8 @@ def train(args): ds_model = deepspeed_utils.prepare_deepspeed_model( args, mmdit=mmdit, - # mmdie=mmdit if train_mmdit else None, - # text_encoder1=text_encoder1 if train_clip_l else None, - # text_encoder2=text_encoder2 if train_clip_g else None, + clip_l=clip_l if train_clip_l else None, + clip_g=clip_g if train_clip_g else None, ) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -498,10 +522,10 @@ def train(args): # acceleratorがなんかよろしくやってくれるらしい if train_mmdit: mmdit = accelerator.prepare(mmdit) - # if train_clip_l: - # text_encoder1 = accelerator.prepare(text_encoder1) - # if train_clip_g: - # text_encoder2 = accelerator.prepare(text_encoder2) + if train_clip_l: + clip_l = accelerator.prepare(clip_l) + if train_clip_g: + clip_g = accelerator.prepare(clip_g) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする @@ -613,7 +637,7 @@ def train(args): # # For --sample_at_first # sd3_train_utils.sample_images( - # accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], mmdit + # accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [clip_l, clip_g], mmdit # ) # following function will be moved to sd3_train_utils @@ -666,6 +690,7 @@ def train(args): return weighting loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 for epoch in range(num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -687,37 +712,45 @@ def train(args): # encode images to latents. images are [-1, 1] latents = vae.encode(batch["images"].to(vae_dtype)).to(weight_dtype) - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.nan_to_num(latents, 0, out=latents) + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + # latents = latents * sdxl_model_util.VAE_SCALE_FACTOR latents = sd3_models.SDVAE.process_in(latents) - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - # not cached, get text encoder outputs - # XXX This does not work yet - input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl = batch["input_ids"] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + lg_out, t5_out, lg_pooled = text_encoder_outputs_list + if args.use_t5xxl_cache_only: + lg_out = None + lg_pooled = None + else: + lg_out = None + t5_out = None + lg_pooled = None + + if lg_out is None or (train_clip_l or train_clip_g): + # not cached or training, so get from text encoders + input_ids_clip_l, input_ids_clip_g, _ = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # TODO support weighted captions - # TODO support length > 75 input_ids_clip_l = input_ids_clip_l.to(accelerator.device) input_ids_clip_g = input_ids_clip_g.to(accelerator.device) - input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) - - # get text encoder outputs: outputs are concatenated - context, pool = sd3_utils.get_cond_from_tokens( - input_ids_clip_l, input_ids_clip_g, input_ids_t5xxl, clip_l, clip_g, t5xxl + lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, [clip_l, clip_g, None], [input_ids_clip_l, input_ids_clip_g, None] ) - else: - # encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - # encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - # pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) - # TODO this reuses SDXL keys, it should be fixed - lg_out = batch["text_encoder_outputs1_list"] - t5_out = batch["text_encoder_outputs2_list"] - pool = batch["text_encoder_pool2_list"] - context = torch.cat([lg_out, t5_out], dim=-2) + + if t5_out is None: + _, _, input_ids_t5xxl = batch["input_ids_list"] + with torch.no_grad(): + input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None + _, t5_out, _ = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl] + ) + + context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled) # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps @@ -748,13 +781,13 @@ def train(args): if torch.any(torch.isnan(context)): accelerator.print("NaN found in context, replacing with zeros") context = torch.nan_to_num(context, 0, out=context) - if torch.any(torch.isnan(pool)): + if torch.any(torch.isnan(lg_pooled)): accelerator.print("NaN found in pool, replacing with zeros") - pool = torch.nan_to_num(pool, 0, out=pool) + lg_pooled = torch.nan_to_num(lg_pooled, 0, out=lg_pooled) # call model with accelerator.autocast(): - model_pred = mmdit(noisy_model_input, timesteps, context=context, y=pool) + model_pred = mmdit(noisy_model_input, timesteps, context=context, y=lg_pooled) # Follow: Section 5 of https://arxiv.org/abs/2206.00364. # Preconditioning of the model outputs. @@ -806,7 +839,7 @@ def train(args): # accelerator.device, # vae, # [tokenizer1, tokenizer2], - # [text_encoder1, text_encoder2], + # [clip_l, clip_g], # mmdit, # ) @@ -875,7 +908,7 @@ def train(args): # accelerator.device, # vae, # [tokenizer1, tokenizer2], - # [text_encoder1, text_encoder2], + # [clip_l, clip_g], # mmdit, # ) @@ -924,7 +957,19 @@ def setup_parser() -> argparse.ArgumentParser: custom_train_functions.add_custom_train_arguments(parser) sd3_train_utils.add_sd3_training_arguments(parser) - # parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + parser.add_argument( + "--train_text_encoder", action="store_true", help="train text encoder (CLIP-L and G) / text encoderも学習する" + ) + # parser.add_argument("--train_t5xxl", action="store_true", help="train T5-XXL / T5-XXLも学習する") + parser.add_argument( + "--use_t5xxl_cache_only", action="store_true", help="cache T5-XXL outputs only / T5-XXLの出力のみキャッシュする" + ) + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=None, + help="maximum token length for T5-XXL. 256 if omitted / T5-XXLの最大トークン数。省略時は256", + ) # TE training is disabled temporarily # parser.add_argument( @@ -962,7 +1007,6 @@ def setup_parser() -> argparse.ArgumentParser: help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", ) - parser.add_argument("--new_caching", action="store_true", help="use new caching method / 新しいキャッシング方法を使う") parser.add_argument( "--skip_latents_validity_check", action="store_true", diff --git a/sdxl_train.py b/sdxl_train.py index ae92d6a3..b6d4afd6 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -17,7 +17,7 @@ init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import deepspeed_utils, sdxl_model_util +from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl import library.train_util as train_util @@ -124,7 +124,16 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizers = [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] # will be removed in the future + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -166,10 +175,10 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2]) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -262,8 +271,9 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -276,6 +286,9 @@ def train(args): train_text_encoder1 = False train_text_encoder2 = False + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + if args.train_text_encoder: # TODO each option for two text encoders? accelerator.print("enable text encoder training") @@ -307,16 +320,17 @@ def train(args): # TextEncoderの出力をキャッシュする if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad - with torch.no_grad(), accelerator.autocast(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) - accelerator.wait_for_everyone() + text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) + + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + + accelerator.wait_for_everyone() if not cache_latents: vae.requires_grad_(False) @@ -403,7 +417,11 @@ def train(args): else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -597,7 +615,7 @@ def train(args): # For --sample_at_first sdxl_train_util.sample_images( - accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet + accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, [text_encoder1, text_encoder2], unet ) loss_recorder = train_util.LossRecorder() @@ -628,9 +646,15 @@ def train(args): latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + # Text Encoder outputs are cached + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list + encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype) + pool2 = pool2.to(accelerator.device, dtype=weight_dtype) + else: + input_ids1, input_ids2 = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning # TODO support weighted captions @@ -646,39 +670,13 @@ def train(args): # else: input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) - # unwrap_model is fine for models not wrapped by accelerator - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, - accelerator=accelerator, + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] ) - else: - encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) - - # # verify that the text encoder outputs are correct - # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl( - # args.max_token_length, - # batch["input_ids"].to(text_encoder1.device), - # batch["input_ids2"].to(text_encoder1.device), - # tokenizer1, - # tokenizer2, - # text_encoder1, - # text_encoder2, - # None if not args.full_fp16 else weight_dtype, - # ) - # b_size = encoder_hidden_states1.shape[0] - # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # logger.info("text encoder outputs verified") + if args.full_fp16: + encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) + pool2 = pool2.to(weight_dtype) # get size embeddings orig_size = batch["original_sizes_hw"] @@ -765,7 +763,7 @@ def train(args): global_step, accelerator.device, vae, - [tokenizer1, tokenizer2], + tokenizers, [text_encoder1, text_encoder2], unet, ) @@ -847,7 +845,7 @@ def train(args): global_step, accelerator.device, vae, - [tokenizer1, tokenizer2], + tokenizers, [text_encoder1, text_encoder2], unet, ) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 5ff060a9..0eaec29b 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -23,7 +23,16 @@ from accelerate.utils import set_seed import accelerate from diffusers import DDPMScheduler, ControlNetModel from safetensors.torch import load_file -from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util +from library import ( + deepspeed_utils, + sai_model_spec, + sdxl_model_util, + sdxl_original_unet, + sdxl_train_util, + strategy_base, + strategy_sd, + strategy_sdxl, +) import library.model_util as model_util import library.train_util as train_util @@ -79,7 +88,14 @@ def train(args): args.seed = random.randint(0, 2**32) set_seed(args.seed) - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) @@ -106,7 +122,7 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) @@ -164,30 +180,30 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # TextEncoderの出力をキャッシュする if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad - with torch.no_grad(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) + text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) + + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + accelerator.wait_for_everyone() # prepare ControlNet-LLLite @@ -242,7 +258,11 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -290,7 +310,7 @@ def train(args): unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) if isinstance(unet, DDP): - unet._set_static_graph() # avoid error for multiple use of the parameter + unet._set_static_graph() # avoid error for multiple use of the parameter if args.gradient_checkpointing: unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる @@ -357,7 +377,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -409,27 +431,26 @@ def train(args): latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + # Text Encoder outputs are cached + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list + encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype) + pool2 = pool2.to(accelerator.device, dtype=weight_dtype) + else: + input_ids1, input_ids2 = batch["input_ids_list"] with torch.no_grad(): - # Get the text embedding for conditioning input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] ) - else: - encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + if args.full_fp16: + encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) + pool2 = pool2.to(weight_dtype) # get size embeddings orig_size = batch["original_sizes_hw"] diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 83969bb1..67ccae62 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,16 +1,21 @@ import argparse import torch +from accelerate import Accelerator from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() -from library import sdxl_model_util, sdxl_train_util, train_util +from library import sdxl_model_util, sdxl_train_util, strategy_base, strategy_sd, strategy_sdxl, train_util import train_network from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + class SdxlNetworkTrainer(train_network.NetworkTrainer): def __init__(self): super().__init__() @@ -49,15 +54,32 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet - def load_tokenizer(self, args): - tokenizer = sdxl_train_util.load_tokenizers(args) - return tokenizer + def get_tokenize_strategy(self, args): + return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) - def is_text_encoder_outputs_cached(self, args): - return args.cache_text_encoder_outputs + def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy): + return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sdxl.SdxlTextEncodingStrategy() + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + return text_encoders + [accelerator.unwrap_model(text_encoders[-1])] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False) + else: + return None def cache_text_encoder_outputs_if_needed( - self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype ): if args.cache_text_encoder_outputs: if not args.lowram: @@ -70,15 +92,13 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): clean_memory_on_device(accelerator.device) # When TE is not be trained, it will not be prepared so we need to use explicit autocast + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) with accelerator.autocast(): - dataset.cache_text_encoder_outputs( - tokenizers, - text_encoders, - accelerator.device, - weight_dtype, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, + dataset.new_cache_text_encoder_outputs( + text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator.is_main_process ) + accelerator.wait_for_everyone() text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU text_encoders[1].to("cpu", dtype=torch.float32) diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index 5df739e2..cbfcef55 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -5,10 +5,10 @@ import regex import torch from library.device_utils import init_ipex + init_ipex() -from library import sdxl_model_util, sdxl_train_util, train_util - +from library import sdxl_model_util, sdxl_train_util, strategy_sd, strategy_sdxl, train_util import train_textual_inversion @@ -41,28 +41,20 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet - def load_tokenizer(self, args): - tokenizer = sdxl_train_util.load_tokenizers(args) - return tokenizer + def get_tokenize_strategy(self, args): + return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] - with torch.enable_grad(): - input_ids1 = input_ids1.to(accelerator.device) - input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizers[0], - tokenizers[1], - text_encoders[0], - text_encoders[1], - None if not args.full_fp16 else weight_dtype, - accelerator=accelerator, - ) - return encoder_hidden_states1, encoder_hidden_states2, pool2 + def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy): + return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sdxl.SdxlTextEncodingStrategy() def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -81,9 +73,11 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) return noise_pred - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): + def sample_images( + self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement + ): sdxl_train_util.sample_images( - accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement + accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement ) def save_weights(self, file, updated_embs, save_dtype, metadata): @@ -122,8 +116,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine def setup_parser() -> argparse.ArgumentParser: parser = train_textual_inversion.setup_parser() - # don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching - # sdxl_train_util.add_sdxl_training_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False) return parser diff --git a/train_db.py b/train_db.py index 39d8ea6e..7caee664 100644 --- a/train_db.py +++ b/train_db.py @@ -11,7 +11,7 @@ import toml from tqdm import tqdm import torch -from library import deepspeed_utils +from library import deepspeed_utils, strategy_base from library.device_utils import init_ipex, clean_memory_on_device @@ -38,6 +38,7 @@ from library.custom_train_functions import ( apply_masked_loss, ) from library.utils import setup_logging, add_logging_arguments +import library.strategy_sd as strategy_sd setup_logging() import logging @@ -58,7 +59,14 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - tokenizer = train_util.load_tokenizer(args) + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -80,10 +88,10 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -145,13 +153,17 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # 学習を準備する:モデルを適切な状態にする train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 unet.requires_grad_(True) # 念のため追加 @@ -184,8 +196,11 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, @@ -290,10 +305,16 @@ def train(args): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs) + accelerator.init_trackers( + "dreambooth" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) # For --sample_at_first - train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -331,7 +352,7 @@ def train(args): with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): if args.weighted_captions: encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, + tokenize_strategy.tokenizer, text_encoder, batch["captions"], accelerator.device, @@ -339,14 +360,18 @@ def train(args): clip_skip=args.clip_skip, ) else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) + input_ids = batch["input_ids_list"][0].to(accelerator.device) + encoder_hidden_states = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder], [input_ids] + )[0] + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) # Predict the noise residual with accelerator.autocast(): @@ -358,7 +383,9 @@ def train(args): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -393,7 +420,7 @@ def train(args): global_step += 1 train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet ) # 指定ステップごとにモデルを保存 @@ -457,7 +484,9 @@ def train(args): vae, ) - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) is_main_process = accelerator.is_main_process if is_main_process: diff --git a/train_network.py b/train_network.py index 7ba07385..3828fed1 100644 --- a/train_network.py +++ b/train_network.py @@ -7,6 +7,7 @@ import random import time import json from multiprocessing import Value +from typing import Any, List import toml from tqdm import tqdm @@ -18,7 +19,7 @@ init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import deepspeed_utils, model_util +from library import deepspeed_utils, model_util, strategy_base, strategy_sd import library.train_util as train_util from library.train_util import DreamBoothDataset @@ -101,19 +102,31 @@ class NetworkTrainer: text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet - def load_tokenizer(self, args): - tokenizer = train_util.load_tokenizer(args) - return tokenizer + def get_tokenize_strategy(self, args): + return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) - def is_text_encoder_outputs_cached(self, args): - return False + def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]: + return [tokenize_strategy.tokenizer] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + True, args.cache_latents_to_disk, args.vae_batch_size, False + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sd.SdTextEncodingStrategy(args.clip_skip) + + def get_text_encoder_outputs_caching_strategy(self, args): + return None + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + return text_encoders def is_train_text_encoder(self, args): - return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args) + return not args.network_train_unet_only - def cache_text_encoder_outputs_if_needed( - self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype - ): + def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype): for t_enc in text_encoders: t_enc.to(accelerator.device, dtype=weight_dtype) @@ -123,7 +136,7 @@ class NetworkTrainer: return encoder_hidden_states def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - noise_pred = unet(noisy_latents, timesteps, text_conds).sample + noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred def all_reduce_network(self, accelerator, network): @@ -131,8 +144,8 @@ class NetworkTrainer: if param.grad is not None: param.grad = accelerator.reduce(param.grad, reduction="mean") - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): - train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet): + train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet) def train(self, args): session_id = random.randint(0, 2**32) @@ -150,9 +163,13 @@ class NetworkTrainer: args.seed = random.randint(0, 2**32) set_seed(args.seed) - # tokenizerは単体またはリスト、tokenizersは必ずリスト:既存のコードとの互換性のため - tokenizer = self.load_tokenizer(args) - tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer] + tokenize_strategy = self.get_tokenize_strategy(args) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = self.get_latents_caching_strategy(args) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -194,11 +211,11 @@ class NetworkTrainer: ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -268,8 +285,9 @@ class NetworkTrainer: vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -277,9 +295,13 @@ class NetworkTrainer: # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu - self.cache_text_encoder_outputs_if_needed( - args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype - ) + text_encoding_strategy = self.get_text_encoding_strategy(args) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + text_encoder_outputs_caching_strategy = self.get_text_encoder_outputs_caching_strategy(args) + if text_encoder_outputs_caching_strategy is not None: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) + self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype) # prepare network net_kwargs = {} @@ -366,7 +388,11 @@ class NetworkTrainer: optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -878,7 +904,7 @@ class NetworkTrainer: os.remove(old_ckpt_file) # For --sample_at_first - self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) # training loop if initial_step > 0: # only if skip_until_initial_step is specified @@ -933,21 +959,31 @@ class NetworkTrainer: # print(f"set multiplier: {multipliers}") accelerator.unwrap_model(network).set_multiplier(multipliers) - with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): - # Get the text embedding for conditioning - if args.weighted_captions: - text_encoder_conds = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - text_encoder_conds = self.get_text_cond( - args, accelerator, batch, tokenizers, text_encoders, weight_dtype - ) + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs + else: + with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): + # Get the text embedding for conditioning + if args.weighted_captions: + # SD only + text_encoder_conds = get_weighted_text_embeddings( + tokenizers[0], + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids, + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -1026,7 +1062,9 @@ class NetworkTrainer: progress_bar.update(1) global_step += 1 - self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + self.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet + ) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -1082,7 +1120,7 @@ class NetworkTrainer: if args.save_state: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) # end of epoch diff --git a/train_textual_inversion.py b/train_textual_inversion.py index ade077c3..9044f50d 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -2,6 +2,7 @@ import argparse import math import os from multiprocessing import Value +from typing import Any, List import toml from tqdm import tqdm @@ -15,7 +16,7 @@ init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler from transformers import CLIPTokenizer -from library import deepspeed_utils, model_util +from library import deepspeed_utils, model_util, strategy_base, strategy_sd import library.train_util as train_util import library.huggingface_util as huggingface_util @@ -103,28 +104,38 @@ class TextualInversionTrainer: def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) - return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet + return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet - def load_tokenizer(self, args): - tokenizer = train_util.load_tokenizer(args) - return tokenizer + def get_tokenize_strategy(self, args): + return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]: + return [tokenize_strategy.tokenizer] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + True, args.cache_latents_to_disk, args.vae_batch_size, False + ) + return latents_caching_strategy def assert_token_string(self, token_string, tokenizers: CLIPTokenizer): pass - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - with torch.enable_grad(): - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None) - return encoder_hidden_states + def get_text_encoding_strategy(self, args): + return strategy_sd.SdTextEncodingStrategy(args.clip_skip) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders) -> List[Any]: + return text_encoders def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - noise_pred = unet(noisy_latents, timesteps, text_conds).sample + noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): + def sample_images( + self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement + ): train_util.sample_images( - accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement + accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoders[0], unet, prompt_replacement ) def save_weights(self, file, updated_embs, save_dtype, metadata): @@ -182,8 +193,13 @@ class TextualInversionTrainer: if args.seed is not None: set_seed(args.seed) - tokenizer_or_list = self.load_tokenizer(args) # list of tokenizer or tokenizer - tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list] + tokenize_strategy = self.get_tokenize_strategy(args) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = self.get_latents_caching_strategy(args) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # acceleratorを準備する logger.info("prepare accelerator") @@ -194,14 +210,7 @@ class TextualInversionTrainer: vae_dtype = torch.float32 if args.no_half_vae else weight_dtype # モデルを読み込む - model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator) - text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list - - if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1: - accelerator.print( - "accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / " - + "accelerateでは複数のモデル(テキストエンコーダー)のgradient_accumulation_stepsはサポートされていないようです" - ) + model_version, text_encoders, vae, unet = self.load_target_model(args, weight_dtype, accelerator) # Convert the init_word to token_id init_token_ids_list = [] @@ -310,10 +319,10 @@ class TextualInversionTrainer: ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list) + train_dataset_group = train_util.load_arbitrary_dataset(args) self.assert_extra_args(args, train_dataset_group) @@ -368,11 +377,10 @@ class TextualInversionTrainer: vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - clean_memory_on_device(accelerator.device) + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() if args.gradient_checkpointing: @@ -387,7 +395,11 @@ class TextualInversionTrainer: trainable_params += text_encoder.get_input_embeddings().parameters() _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -415,20 +427,8 @@ class TextualInversionTrainer: lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # acceleratorがなんかよろしくやってくれるらしい - if len(text_encoders) == 1: - text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder_or_list, optimizer, train_dataloader, lr_scheduler - ) - - elif len(text_encoders) == 2: - text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler - ) - - text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2] - - else: - raise NotImplementedError() + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + text_encoders = [accelerator.prepare(text_encoder) for text_encoder in text_encoders] index_no_updates_list = [] orig_embeds_params_list = [] @@ -456,6 +456,9 @@ class TextualInversionTrainer: else: unet.eval() + text_encoding_strategy = self.get_text_encoding_strategy(args) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する vae.requires_grad_(False) vae.eval() @@ -510,7 +513,9 @@ class TextualInversionTrainer: if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) # function for saving/removing @@ -540,8 +545,8 @@ class TextualInversionTrainer: global_step, accelerator.device, vae, - tokenizer_or_list, - text_encoder_or_list, + tokenizers, + text_encoders, unet, prompt_replacement, ) @@ -568,7 +573,12 @@ class TextualInversionTrainer: latents = latents * self.vae_scale_factor # Get the text embedding for conditioning - text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype) + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -588,7 +598,9 @@ class TextualInversionTrainer: else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -639,8 +651,8 @@ class TextualInversionTrainer: global_step, accelerator.device, vae, - tokenizer_or_list, - text_encoder_or_list, + tokenizers, + text_encoders, unet, prompt_replacement, ) @@ -722,8 +734,8 @@ class TextualInversionTrainer: global_step, accelerator.device, vae, - tokenizer_or_list, - text_encoder_or_list, + tokenizers, + text_encoders, unet, prompt_replacement, ) From 1a977e847a10975c042c0fdacd871a33c9e93900 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 27 Jul 2024 13:51:50 +0900 Subject: [PATCH 091/356] fix typos --- library/strategy_base.py | 2 +- library/strategy_sd.py | 2 +- library/strategy_sd3.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/library/strategy_base.py b/library/strategy_base.py index 594cca5e..a99a0829 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -218,7 +218,7 @@ class LatentsCachingStrategy: def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): raise NotImplementedError - def _defualt_is_disk_cached_latents_expected( + def _default_is_disk_cached_latents_expected( self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool ): if not self.cache_to_disk: diff --git a/library/strategy_sd.py b/library/strategy_sd.py index 10581614..83ffaa31 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -125,7 +125,7 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index 42630ab2..7491e814 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -177,7 +177,7 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy): ) def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): From 002d75179ae5a3b165a65c5cf49c00bf8f98e2df Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 29 Jul 2024 23:18:34 +0900 Subject: [PATCH 092/356] sample images for training --- library/sd3_train_utils.py | 348 ++++++++++++++++++++++++++++++++++++- sd3_train.py | 51 +++--- 2 files changed, 367 insertions(+), 32 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 8f99d947..da072950 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -1,14 +1,18 @@ import argparse -import glob import math import os -from typing import List, Optional, Tuple, Union +import toml +import json +import time +from typing import Dict, List, Optional, Tuple, Union import torch from safetensors.torch import save_file -from accelerate import Accelerator +from accelerate import Accelerator, PartialState +from tqdm import tqdm +from PIL import Image -from library import sd3_models, sd3_utils, train_util +from library import sd3_models, sd3_utils, strategy_base, train_util from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -276,10 +280,342 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin ) -def sample_images(*args, **kwargs): - return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) +# temporary copied from sd3_minimal_inferece.py +def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): + start = sampling.timestep(sampling.sigma_max) + end = sampling.timestep(sampling.sigma_min) + timesteps = torch.linspace(start, end, steps) + sigs = [] + for x in range(len(timesteps)): + ts = timesteps[x] + sigs.append(sampling.sigma(ts)) + sigs += [0.0] + return torch.FloatTensor(sigs) + + +def max_denoise(model_sampling, sigmas): + max_sigma = float(model_sampling.sigma_max) + sigma = float(sigmas[0]) + return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma + + +def do_sample( + height: int, + width: int, + seed: int, + cond: Tuple[torch.Tensor, torch.Tensor], + neg_cond: Tuple[torch.Tensor, torch.Tensor], + mmdit: sd3_models.MMDiT, + steps: int, + guidance_scale: float, + dtype: torch.dtype, + device: str, +): + latent = torch.zeros(1, 16, height // 8, width // 8, device=device) + latent = latent.to(dtype).to(device) + + # noise = get_noise(seed, latent).to(device) + if seed is not None: + generator = torch.manual_seed(seed) + noise = ( + torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu") + .to(latent.dtype) + .to(device) + ) + + model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 + + sigmas = get_sigmas(model_sampling, steps).to(device) + + noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas)) + + c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype) + y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype) + + x = noise_scaled.to(device).to(dtype) + # print(x.shape) + + with torch.no_grad(): + for i in tqdm(range(len(sigmas) - 1)): + sigma_hat = sigmas[i] + + timestep = model_sampling.timestep(sigma_hat).float() + timestep = torch.FloatTensor([timestep, timestep]).to(device) + + x_c_nc = torch.cat([x, x], dim=0) + # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + model_output = model_output.float() + batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) + + pos_out, neg_out = batched.chunk(2) + denoised = neg_out + (pos_out - neg_out) * guidance_scale + # print(denoised.shape) + + # d = to_d(x, sigma_hat, denoised) + dims_to_append = x.ndim - sigma_hat.ndim + sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] + # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) + """Converts a denoiser output to a Karras ODE derivative.""" + d = (x - denoised) / sigma_hat_dims + + dt = sigmas[i + 1] - sigma_hat + + # Euler method + x = x + d * dt + x = x.to(dtype) + + return x + + +def load_prompts(prompt_file: str) -> List[Dict]: + # read prompts + if prompt_file.endswith(".txt"): + with open(prompt_file, "r", encoding="utf-8") as f: + lines = f.readlines() + prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] + elif prompt_file.endswith(".toml"): + with open(prompt_file, "r", encoding="utf-8") as f: + data = toml.load(f) + prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] + elif prompt_file.endswith(".json"): + with open(prompt_file, "r", encoding="utf-8") as f: + prompts = json.load(f) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from library.train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + return prompts + + +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + mmdit, + vae, + text_encoders, + sample_prompts_te_outputs, + prompt_replacement=None, +): + if steps == 0: + if not args.sample_at_first: + return + else: + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return + + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts): + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return + + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap unet and text_encoder(s) + mmdit = accelerator.unwrap_model(mmdit) + text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = load_prompts(args.sample_prompts) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + org_vae_device = vae.device # will be on cpu + vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(): + for prompt_dict in prompts: + sample_image_inference( + accelerator, + args, + mmdit, + text_encoders, + vae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, + args, + mmdit, + text_encoders, + vae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + vae.to(org_vae_device) + + clean_memory_on_device(accelerator.device) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + mmdit: sd3_models.MMDiT, + text_encoders: List[Union[sd3_models.SDClipModel, sd3_models.SDXLClipG, sd3_models.T5XXLModel]], + vae: sd3_models.SDVAE, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, +): + assert isinstance(prompt_dict, dict) + negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 30) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 7.5) + seed = prompt_dict.get("seed") + # controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + if negative_prompt is None: + negative_prompt = "" + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + logger.info(f"prompt: {prompt}") + logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + # logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: + te_outputs = sample_prompts_te_outputs[prompt] + else: + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(prompt) + te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens]) + + lg_out, t5_out, pooled = te_outputs + cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # encode negative prompts + if sample_prompts_te_outputs and negative_prompt in sample_prompts_te_outputs: + neg_te_outputs = sample_prompts_te_outputs[negative_prompt] + else: + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(negative_prompt) + neg_te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens]) + + lg_out, t5_out, pooled = neg_te_outputs + neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # sample image + latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device) + latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype)) + + # latent to image + with torch.no_grad(): + image = vae.decode(latents) + image = image.float() + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) + decoded_np = decoded_np.astype(np.uint8) + + image = Image.fromarray(decoded_np) + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # wandb有効時のみログを送信 + try: + wandb_tracker = accelerator.get_tracker("wandb") + try: + import wandb + except ImportError: # 事前に一度確認するのでここはエラー出ないはず + raise ImportError("No wandb / wandb がインストールされていないようです") + + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + except: # wandb 無効時 + pass + # region Diffusers diff --git a/sd3_train.py b/sd3_train.py index 617e3027..2f4ea8cb 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -299,6 +299,7 @@ def train(args): t5xxl.eval() # cache text encoder outputs + sample_prompts_te_outputs = None if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad here clip_l.to(accelerator.device) @@ -321,6 +322,22 @@ def train(args): with accelerator.autocast(): train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + prompts = sd3_train_utils.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_list = sd3_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_list + ) + accelerator.wait_for_everyone() # load MMDIT @@ -635,10 +652,8 @@ def train(args): init_kwargs=init_kwargs, ) - # # For --sample_at_first - # sd3_train_utils.sample_images( - # accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [clip_l, clip_g], mmdit - # ) + # For --sample_at_first + sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs) # following function will be moved to sd3_train_utils @@ -831,17 +846,9 @@ def train(args): progress_bar.update(1) global_step += 1 - # sdxl_train_util.sample_images( - # accelerator, - # args, - # None, - # global_step, - # accelerator.device, - # vae, - # [tokenizer1, tokenizer2], - # [clip_l, clip_g], - # mmdit, - # ) + sd3_train_utils.sample_images( + accelerator, args, None, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs + ) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -900,17 +907,9 @@ def train(args): vae, ) - # sdxl_train_util.sample_images( - # accelerator, - # args, - # epoch + 1, - # global_step, - # accelerator.device, - # vae, - # [tokenizer1, tokenizer2], - # [clip_l, clip_g], - # mmdit, - # ) + sd3_train_utils.sample_images( + accelerator, args, epoch + 1, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs + ) is_main_process = accelerator.is_main_process # if is_main_process: From 231df197ddf4372b3d90751146927f33e1965d1a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 5 Aug 2024 20:26:30 +0900 Subject: [PATCH 093/356] Fix npz path for verification --- library/strategy_sdxl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py index a4513336..3eb0ab6f 100644 --- a/library/strategy_sdxl.py +++ b/library/strategy_sdxl.py @@ -184,20 +184,20 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): def get_outputs_npz_path(self, image_abs_path: str) -> str: return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX - def is_disk_cached_outputs_expected(self, abs_path: str): + def is_disk_cached_outputs_expected(self, npz_path: str): if not self.cache_to_disk: return False - if not os.path.exists(self.get_outputs_npz_path(abs_path)): + if not os.path.exists(npz_path): return False if self.skip_disk_cache_validity_check: return True try: - npz = np.load(self.get_outputs_npz_path(abs_path)) + npz = np.load(npz_path) if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz: return False except Exception as e: - logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}") + logger.error(f"Error loading file: {npz_path}") raise e return True From da4d0fe0165b3e0143c237de8cf307d53a9de45a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 5 Aug 2024 20:51:34 +0900 Subject: [PATCH 094/356] support attn mask for l+g/t5 --- library/strategy_sd3.py | 88 +++++++++++++++++++++++++++++++++------- library/train_util.py | 3 +- sd3_minimal_inference.py | 10 +++-- sd3_train.py | 30 +++++++++++--- 4 files changed, 107 insertions(+), 24 deletions(-) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index 7491e814..a2281890 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -37,11 +37,14 @@ class Sd3TokenizeStrategy(TokenizeStrategy): g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt") + l_attn_mask = l_tokens["attention_mask"] + g_attn_mask = g_tokens["attention_mask"] + t5_attn_mask = t5_tokens["attention_mask"] l_tokens = l_tokens["input_ids"] g_tokens = g_tokens["input_ids"] t5_tokens = t5_tokens["input_ids"] - return [l_tokens, g_tokens, t5_tokens] + return [l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask] class Sd3TextEncodingStrategy(TextEncodingStrategy): @@ -49,11 +52,20 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): pass def encode_tokens( - self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: List[torch.Tensor], + apply_lg_attn_mask: bool = False, + apply_t5_attn_mask: bool = False, ) -> List[torch.Tensor]: + """ + returned embeddings are not masked + """ clip_l, clip_g, t5xxl = models - l_tokens, g_tokens, t5_tokens = tokens + l_tokens, g_tokens, t5_tokens = tokens[:3] + l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] if len(tokens) > 3 else [None, None, None] if l_tokens is None: assert g_tokens is None, "g_tokens must be None if l_tokens is None" lg_out = None @@ -61,10 +73,15 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" l_out, l_pooled = clip_l(l_tokens) g_out, g_pooled = clip_g(g_tokens) + if apply_lg_attn_mask: + l_out = l_out * l_attn_mask.to(l_out.device).unsqueeze(-1) + g_out = g_out * g_attn_mask.to(g_out.device).unsqueeze(-1) lg_out = torch.cat([l_out, g_out], dim=-1) if t5xxl is not None and t5_tokens is not None: t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096] + if apply_t5_attn_mask: + t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) else: t5_out = None @@ -84,50 +101,81 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz" def __init__( - self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + apply_lg_attn_mask: bool = False, + apply_t5_attn_mask: bool = False, ) -> None: super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + self.apply_lg_attn_mask = apply_lg_attn_mask + self.apply_t5_attn_mask = apply_t5_attn_mask def get_outputs_npz_path(self, image_abs_path: str) -> str: return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX - def is_disk_cached_outputs_expected(self, abs_path: str): + def is_disk_cached_outputs_expected(self, npz_path: str): if not self.cache_to_disk: return False - if not os.path.exists(self.get_outputs_npz_path(abs_path)): + if not os.path.exists(npz_path): return False if self.skip_disk_cache_validity_check: return True try: - npz = np.load(self.get_outputs_npz_path(abs_path)) - if "clip_l" not in npz or "clip_g" not in npz: + npz = np.load(npz_path) + if "lg_out" not in npz: return False - if "clip_l_pool" not in npz or "clip_g_pool" not in npz: + if "lg_pooled" not in npz: + return False + if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used return False # t5xxl is optional except Exception as e: - logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}") + logger.error(f"Error loading file: {npz_path}") raise e return True + def mask_lg_attn(self, lg_out: np.ndarray, l_attn_mask: np.ndarray, g_attn_mask: np.ndarray) -> np.ndarray: + l_out = lg_out[..., :768] + g_out = lg_out[..., 768:] # 1280 + l_out = l_out * np.expand_dims(l_attn_mask, -1) # l_out = l_out * l_attn_mask. + g_out = g_out * np.expand_dims(g_attn_mask, -1) # g_out = g_out * g_attn_mask. + return np.concatenate([l_out, g_out], axis=-1) + + def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray: + return t5_out * np.expand_dims(t5_attn_mask, -1) + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: data = np.load(npz_path) lg_out = data["lg_out"] lg_pooled = data["lg_pooled"] t5_out = data["t5_out"] if "t5_out" in data else None + + if self.apply_lg_attn_mask: + l_attn_mask = data["clip_l_attn_mask"] + g_attn_mask = data["clip_g_attn_mask"] + lg_out = self.mask_lg_attn(lg_out, l_attn_mask, g_attn_mask) + + if self.apply_t5_attn_mask and t5_out is not None: + t5_attn_mask = data["t5_attn_mask"] + t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) + return [lg_out, t5_out, lg_pooled] def cache_batch_outputs( self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List ): + sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy captions = [info.caption for info in infos] - clip_l_tokens, clip_g_tokens, t5xxl_tokens = tokenize_strategy.tokenize(captions) + tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): - lg_out, t5_out, lg_pooled = text_encoding_strategy.encode_tokens( - tokenize_strategy, models, [clip_l_tokens, clip_g_tokens, t5xxl_tokens] + lg_out, t5_out, lg_pooled = sd3_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask ) if lg_out.dtype == torch.bfloat16: @@ -148,10 +196,22 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): lg_pooled_i = lg_pooled[i] if self.cache_to_disk: + clip_l_attn_mask, clip_g_attn_mask, t5_attn_mask = tokens_and_masks[3:6] + clip_l_attn_mask_i = clip_l_attn_mask[i].cpu().numpy() + clip_g_attn_mask_i = clip_g_attn_mask[i].cpu().numpy() + t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() if t5_attn_mask is not None else None # shouldn't be None kwargs = {} if t5_out is not None: kwargs["t5_out"] = t5_out_i - np.savez(info.text_encoder_outputs_npz, lg_out=lg_out_i, lg_pooled=lg_pooled_i, **kwargs) + np.savez( + info.text_encoder_outputs_npz, + lg_out=lg_out_i, + lg_pooled=lg_pooled_i, + clip_l_attn_mask=clip_l_attn_mask_i, + clip_g_attn_mask=clip_g_attn_mask_i, + t5_attn_mask=t5_attn_mask_i, + **kwargs, + ) else: info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i) diff --git a/library/train_util.py b/library/train_util.py index a747e047..fc458a88 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -646,7 +646,7 @@ class BaseDataset(torch.utils.data.Dataset): # caching self.caching_mode = None # None, 'latents', 'text' - + self.tokenize_strategy = None self.text_encoder_output_caching_strategy = None self.latents_caching_strategy = None @@ -1486,6 +1486,7 @@ class BaseDataset(torch.utils.data.Dataset): text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( image_info.text_encoder_outputs_npz ) + text_encoder_outputs = [torch.FloatTensor(x) for x in text_encoder_outputs] else: tokenization_required = True text_encoder_outputs_list.append(text_encoder_outputs) diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index e9e61af1..630da7e0 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -146,6 +146,8 @@ if __name__ == "__main__": parser.add_argument("--clip_l", type=str, required=False) parser.add_argument("--t5xxl", type=str, required=False) parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77") + parser.add_argument("--apply_lg_attn_mask", action="store_true") + parser.add_argument("--apply_t5_attn_mask", action="store_true") parser.add_argument("--prompt", type=str, default="A photo of a cat") # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders parser.add_argument("--negative_prompt", type=str, default="") @@ -323,15 +325,15 @@ if __name__ == "__main__": logger.info("Encoding prompts...") encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() - l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.prompt) + tokens_and_masks = tokenize_strategy.tokenize(args.prompt) lg_out, t5_out, pooled = encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens] + tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask ) cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) - l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.negative_prompt) + tokens_and_masks = tokenize_strategy.tokenize(args.negative_prompt) lg_out, t5_out, pooled = encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens] + tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask ) neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) diff --git a/sd3_train.py b/sd3_train.py index 2f4ea8cb..9c37cbce 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -172,6 +172,8 @@ def train(args): args.text_encoder_batch_size, False, False, + False, + False, ) ) train_dataset_group.set_current_strategies() @@ -312,6 +314,8 @@ def train(args): args.text_encoder_batch_size, False, train_clip_g or train_clip_l or args.use_t5xxl_cache_only, + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, ) strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) @@ -335,7 +339,11 @@ def train(args): logger.info(f"cache Text Encoder outputs for prompt: {p}") tokens_list = sd3_tokenize_strategy.tokenize(p) sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( - sd3_tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_list + sd3_tokenize_strategy, + [clip_l, clip_g, t5xxl], + tokens_list, + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, ) accelerator.wait_for_everyone() @@ -748,21 +756,23 @@ def train(args): if lg_out is None or (train_clip_l or train_clip_g): # not cached or training, so get from text encoders - input_ids_clip_l, input_ids_clip_g, _ = batch["input_ids_list"] + input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # TODO support weighted captions input_ids_clip_l = input_ids_clip_l.to(accelerator.device) input_ids_clip_g = input_ids_clip_g.to(accelerator.device) lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens( - sd3_tokenize_strategy, [clip_l, clip_g, None], [input_ids_clip_l, input_ids_clip_g, None] + sd3_tokenize_strategy, + [clip_l, clip_g, None], + [input_ids_clip_l, input_ids_clip_g, None, l_attn_mask, g_attn_mask, None], ) if t5_out is None: - _, _, input_ids_t5xxl = batch["input_ids_list"] + _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] with torch.no_grad(): input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None _, t5_out, _ = text_encoding_strategy.encode_tokens( - sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl] + sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] ) context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled) @@ -969,6 +979,16 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="maximum token length for T5-XXL. 256 if omitted / T5-XXLの最大トークン数。省略時は256", ) + parser.add_argument( + "--apply_lg_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する", + ) + parser.add_argument( + "--apply_t5_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", + ) # TE training is disabled temporarily # parser.add_argument( From 36b2e6fc288c57f496a061e4d638f5641c32c9ea Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 9 Aug 2024 22:56:48 +0900 Subject: [PATCH 095/356] add FLUX.1 LoRA training --- README.md | 20 + flux_minimal_inference.py | 390 ++++++++++++++++ flux_train_network.py | 332 ++++++++++++++ library/flux_models.py | 920 ++++++++++++++++++++++++++++++++++++++ library/flux_utils.py | 215 +++++++++ library/sd3_models.py | 22 +- library/strategy_flux.py | 244 ++++++++++ networks/lora_flux.py | 730 ++++++++++++++++++++++++++++++ sdxl_train_network.py | 5 + train_network.py | 169 ++++--- 10 files changed, 2992 insertions(+), 55 deletions(-) create mode 100644 flux_minimal_inference.py create mode 100644 flux_train_network.py create mode 100644 library/flux_models.py create mode 100644 library/flux_utils.py create mode 100644 library/strategy_flux.py create mode 100644 networks/lora_flux.py diff --git a/README.md b/README.md index d406fecd..a0b02f10 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,25 @@ This repository contains training, generation and utility scripts for Stable Diffusion. +## FLUX.1 LoRA training (WIP) + +__Aug 9, 2024__: + +Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. + +We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. + +``` +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name +``` + +The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. + +``` +python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors +``` + +Unfortnately the training result is not good. Please let us know if you have any idea to improve the training. + ## SD3 training SD3 training is done with `sd3_train.py`. diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py new file mode 100644 index 00000000..f3affca8 --- /dev/null +++ b/flux_minimal_inference.py @@ -0,0 +1,390 @@ +# Minimum Inference Code for FLUX + +import argparse +import datetime +import math +import os +import random +from typing import Callable, Optional, Tuple +import einops +import numpy as np + +import torch +from safetensors.torch import safe_open, load_file +from tqdm import tqdm +from PIL import Image +import accelerate + +from library import device_utils +from library.device_utils import init_ipex, get_preferred_device + +init_ipex() + + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import networks.lora_flux as lora_flux +from library import flux_models, flux_utils, sd3_utils, strategy_flux + + +def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: flux_models.Flux, + img: torch.Tensor, + img_ids: torch.Tensor, + txt: torch.Tensor, + txt_ids: torch.Tensor, + vec: torch.Tensor, + timesteps: list[float], + guidance: float = 4.0, +): + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec) + + img = img + (t_prev - t_curr) * pred + + return img + + +def do_sample( + accelerator: Optional[accelerate.Accelerator], + model: flux_models.Flux, + img: torch.Tensor, + img_ids: torch.Tensor, + l_pooled: torch.Tensor, + t5_out: torch.Tensor, + txt_ids: torch.Tensor, + num_steps: int, + guidance: float, + is_schnell: bool, + device: torch.device, + flux_dtype: torch.dtype, +): + timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell) + + # denoise initial noise + if accelerator: + with accelerator.autocast(), torch.no_grad(): + x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance) + else: + with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad(): + x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance) + + return x + + +def generate_image( + model, + clip_l, + t5xxl, + ae, + prompt: str, + seed: Optional[int], + image_width: int, + image_height: int, + steps: Optional[int], + guidance: float, +): + # make first noise with packed shape + # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2 + packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16) + noise = torch.randn( + 1, + packed_latent_height * packed_latent_width, + 16 * 2 * 2, + device=device, + dtype=dtype, + generator=torch.Generator(device=device).manual_seed(seed), + ) + + # prepare img and img ids + + # this is needed only for img2img + # img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + # if img.shape[0] == 1 and bs > 1: + # img = repeat(img, "1 ... -> bs ...", bs=bs) + + # txt2img only needs img_ids + img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width) + + # prepare embeddings + logger.info("Encoding prompts...") + tokens_and_masks = tokenize_strategy.tokenize(prompt) + clip_l = clip_l.to(device) + t5xxl = t5xxl.to(device) + with torch.no_grad(): + if is_fp8(clip_l_dtype) or is_fp8(t5xxl_dtype): + clip_l.to(clip_l_dtype) + t5xxl.to(t5xxl_dtype) + with accelerator.autocast(): + _, t5_out, txt_ids = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + else: + with torch.autocast(device_type=device.type, dtype=clip_l_dtype): + l_pooled, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): + _, t5_out, txt_ids = encoding_strategy.encode_tokens( + tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + + # NaN check + if torch.isnan(l_pooled).any(): + raise ValueError("NaN in l_pooled") + if torch.isnan(t5_out).any(): + raise ValueError("NaN in t5_out") + + if args.offload: + clip_l = clip_l.cpu() + t5xxl = t5xxl.cpu() + # del clip_l, t5xxl + device_utils.clean_memory() + + # generate image + logger.info("Generating image...") + model = model.to(device) + if steps is None: + steps = 4 if is_schnell else 50 + + img_ids = img_ids.to(device) + x = do_sample( + accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance_scale, is_schnell, device, flux_dtype + ) + if args.offload: + model = model.cpu() + # del model + device_utils.clean_memory() + + # unpack + x = x.float() + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) + + # decode + logger.info("Decoding image...") + ae = ae.to(device) + with torch.no_grad(): + if is_fp8(ae_dtype): + with accelerator.autocast(): + x = ae.decode(x) + else: + with torch.autocast(device_type=device.type, dtype=ae_dtype): + x = ae.decode(x) + if args.offload: + ae = ae.cpu() + + x = x.clamp(-1, 1) + x = x.permute(0, 2, 3, 1) + img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + + # save image + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") + img.save(output_path) + + logger.info(f"Saved image to {output_path}") + + +if __name__ == "__main__": + target_height = 768 # 1024 + target_width = 1360 # 1024 + + # steps = 50 # 28 # 50 + # guidance_scale = 5 + # seed = 1 # None # 1 + + device = get_preferred_device() + + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--clip_l", type=str, required=False) + parser.add_argument("--t5xxl", type=str, required=False) + parser.add_argument("--ae", type=str, required=False) + parser.add_argument("--apply_t5_attn_mask", action="store_true") + parser.add_argument("--prompt", type=str, default="A photo of a cat") + parser.add_argument("--output_dir", type=str, default=".") + parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype") + parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l") + parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae") + parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl") + parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux") + parser.add_argument("--seed", type=int, default=None) + parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev") + parser.add_argument("--guidance", type=float, default=3.5) + parser.add_argument("--offload", action="store_true", help="Offload to CPU") + parser.add_argument( + "--lora_weights", + type=str, + nargs="*", + default=[], + help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)", + ) + parser.add_argument("--width", type=int, default=target_width) + parser.add_argument("--height", type=int, default=target_height) + parser.add_argument("--interactive", action="store_true") + args = parser.parse_args() + + seed = args.seed + steps = args.steps + guidance_scale = args.guidance + + name = "schnell" if "schnell" in args.ckpt_path else "dev" # TODO change this to a more robust way + is_schnell = name == "schnell" + + def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: + if s is None: + return default_dtype + if s in ["bf16", "bfloat16"]: + return torch.bfloat16 + elif s in ["fp16", "float16"]: + return torch.float16 + elif s in ["fp32", "float32"]: + return torch.float32 + elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]: + return torch.float8_e4m3fn + elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]: + return torch.float8_e4m3fnuz + elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]: + return torch.float8_e5m2 + elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]: + return torch.float8_e5m2fnuz + elif s in ["fp8", "float8"]: + return torch.float8_e4m3fn # default fp8 + else: + raise ValueError(f"Unsupported dtype: {s}") + + def is_fp8(dt): + return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz] + + dtype = str_to_dtype(args.dtype) + clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype) + t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype) + ae_dtype = str_to_dtype(args.ae_dtype, dtype) + flux_dtype = str_to_dtype(args.flux_dtype, dtype) + + logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}") + + loading_device = "cpu" if args.offload else device + + use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]] + if any(use_fp8): + accelerator = accelerate.Accelerator(mixed_precision="bf16") + else: + accelerator = None + + # load clip_l + logger.info(f"Loading clip_l from {args.clip_l}...") + clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device) + clip_l.eval() + + logger.info(f"Loading t5xxl from {args.t5xxl}...") + t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device) + t5xxl.eval() + + if is_fp8(clip_l_dtype): + clip_l = accelerator.prepare(clip_l) + if is_fp8(t5xxl_dtype): + t5xxl = accelerator.prepare(t5xxl) + + t5xxl_max_length = 256 if is_schnell else 512 + tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length) + encoding_strategy = strategy_flux.FluxTextEncodingStrategy() + + # DiT + model = flux_utils.load_flow_model(name, args.ckpt_path, flux_dtype, loading_device) + model.eval() + logger.info(f"Casting model to {flux_dtype}") + model.to(flux_dtype) # make sure model is dtype + if is_fp8(flux_dtype): + model = accelerator.prepare(model) + + # AE + ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device) + ae.eval() + if is_fp8(ae_dtype): + ae = accelerator.prepare(ae) + + # LoRA + for weights_file in args.lora_weights: + if ";" in weights_file: + weights_file, multiplier = weights_file.split(";") + multiplier = float(multiplier) + else: + multiplier = 1.0 + + lora_model, weights_sd = lora_flux.create_network_from_weights( + multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True + ) + lora_model.merge_to([clip_l, t5xxl], model, weights_sd) + + if not args.interactive: + generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance) + else: + # loop for interactive + width = target_width + height = target_height + steps = None + guidance = args.guidance + + while True: + print("Enter prompt (empty to exit). Options: --w --h --s --d --g ") + prompt = input() + if prompt == "": + break + + # parse options + options = prompt.split("--") + prompt = options[0].strip() + seed = None + for opt in options[1:]: + opt = opt.strip() + if opt.startswith("w"): + width = int(opt[1:].strip()) + elif opt.startswith("h"): + height = int(opt[1:].strip()) + elif opt.startswith("s"): + steps = int(opt[1:].strip()) + elif opt.startswith("d"): + seed = int(opt[1:].strip()) + elif opt.startswith("g"): + guidance = float(opt[1:].strip()) + + generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance) + + logger.info("Done!") diff --git a/flux_train_network.py b/flux_train_network.py new file mode 100644 index 00000000..7c762c86 --- /dev/null +++ b/flux_train_network.py @@ -0,0 +1,332 @@ +import argparse +import copy +import math +import random +from typing import Any + +import torch +from accelerate import Accelerator +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from library import flux_models, flux_utils, sd3_train_utils, sd3_utils, sdxl_model_util, sdxl_train_util, strategy_flux, train_util +import train_network +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class FluxNetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + + def assert_extra_args(self, args, train_dataset_group): + super().assert_extra_args(args, train_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + assert ( + args.network_train_unet_only or not args.cache_text_encoder_outputs + ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + + train_dataset_group.verify_bucket_reso_steps(32) + + def load_target_model(self, args, weight_dtype, accelerator): + # currently offload to cpu for some models + + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu") + clip_l.eval() + + # loading t5xxl to cpu takes a long time, so we should load to gpu in future + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu") + t5xxl.eval() + + name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # TODO change this to a more robust way + # if we load to cpu, flux.to(fp8) takes a long time + model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") + ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") + + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + + def get_tokenize_strategy(self, args): + return strategy_flux.FluxTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy): + return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_flux.FluxTextEncodingStrategy() + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + return strategy_flux.FluxTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) + accelerator.wait_for_everyone() + + logger.info("move text encoders back to cpu") + text_encoders[0].to("cpu") # , dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU + text_encoders[1].to("cpu") # , dtype=torch.float32) + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) + + # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # # get size embeddings + # orig_size = batch["original_sizes_hw"] + # crop_size = batch["crop_top_lefts"] + # target_size = batch["target_sizes_hw"] + # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # # concat embeddings + # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + # return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): + # logger.warning("Sampling images is not supported for Flux model") + pass + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images).latent_dist.sample() + + def shift_scale_latents(self, args, latents): + return latents + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: flux_models.Flux, + network, + weight_dtype, + train_unet, + ): + # copy from sd3_train.py and modified + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = self.noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = self.noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None + ): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) + + # Add noise according to flow matching. + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance + guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) + img_ids.requires_grad_(True) + guidance_vec.requires_grad_(True) + + # Predict the noise residual + l_pooled, t5_out, txt_ids = text_encoder_conds + # print( + # f"model_input: {noisy_model_input.shape}, img_ids: {img_ids.shape}, t5_out: {t5_out.shape}, txt_ids: {txt_ids.shape}, l_pooled: {l_pooled.shape}, timesteps: {timesteps.shape}, guidance_vec: {guidance_vec.shape}" + # ) + + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + return model_pred, target, timesteps, None, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + # sdxl_train_util.add_sdxl_training_arguments(parser) + parser.add_argument("--clip_l", type=str, help="path to clip_l") + parser.add_argument("--t5xxl", type=str, help="path to t5xxl") + parser.add_argument("--ae", type=str, help="path to ae") + parser.add_argument("--apply_t5_attn_mask", action="store_true") + parser.add_argument( + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) + + # copy from Diffusers + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = FluxNetworkTrainer() + trainer.train(args) diff --git a/library/flux_models.py b/library/flux_models.py new file mode 100644 index 00000000..d0955e37 --- /dev/null +++ b/library/flux_models.py @@ -0,0 +1,920 @@ +# copy from FLUX repo: https://github.com/black-forest-labs/flux +# license: Apache-2.0 License + + +from dataclasses import dataclass +import math + +import torch +from einops import rearrange +from torch import Tensor, nn +from torch.utils.checkpoint import checkpoint + +# USE_REENTRANT = True + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +# region autoencoder + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) + + +# endregion +# region config + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: str | None + ae_path: str | None + # repo_id: str | None + # repo_flow: str | None + # repo_ae: str | None + + +configs = { + "dev": ModelSpec( + # repo_id="black-forest-labs/FLUX.1-dev", + # repo_flow="flux1-dev.sft", + # repo_ae="ae.sft", + ckpt_path=None, # os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=None, # os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "schnell": ModelSpec( + # repo_id="black-forest-labs/FLUX.1-schnell", + # repo_flow="flux1-schnell.sft", + # repo_ae="ae.sft", + ckpt_path=None, # os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=None, # os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +# endregion + +# region math + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: + q, k = apply_rope(q, k, pe) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + +# endregion + + +# region layers +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + # def forward(self, x): + # if self.training and self.gradient_checkpointing: + # def create_custom_forward(func): + # def custom_forward(*inputs): + # return func(*inputs) + # return custom_forward + # return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, use_reentrant=USE_REENTRANT) + # else: + # return self._forward(x) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + # return (x * rrms).to(dtype=x_dtype) * self.scale + return ((x * rrms) * self.scale.float()).to(dtype=x_dtype) + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + # self.gradient_checkpointing = False + + # def enable_gradient_checkpointing(self): + # self.gradient_checkpointing = True + + def forward(self, x: Tensor, pe: Tensor) -> Tensor: + qkv = self.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = self.proj(x) + return x + + # def forward(self, *args, **kwargs): + # if self.training and self.gradient_checkpointing: + # return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + # else: + # return self._forward(*args, **kwargs) + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + # self.img_attn.enable_gradient_checkpointing() + # self.txt_attn.enable_gradient_checkpointing() + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + # self.img_attn.disable_gradient_checkpointing() + # self.txt_attn.disable_gradient_checkpointing() + + def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + return img, txt + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + # def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): + # if self.training and self.gradient_checkpointing: + # def create_custom_forward(func): + # def custom_forward(*inputs): + # return func(*inputs) + # return custom_forward + # return torch.utils.checkpoint.checkpoint( + # create_custom_forward(self._forward), img, txt, vec, pe, use_reentrant=USE_REENTRANT + # ) + # else: + # return self._forward(img, txt, vec, pe) + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + # compute attention + attn = attention(q, k, v, pe=pe) + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + return x + mod.gate * output + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + # def forward(self, x: Tensor, vec: Tensor, pe: Tensor): + # if self.training and self.gradient_checkpointing: + # def create_custom_forward(func): + # def custom_forward(*inputs): + # return func(*inputs) + # return custom_forward + # return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=USE_REENTRANT) + # else: + # return self._forward(x, vec, pe) + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x + + +# endregion + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + self.gradient_checkpointing = False + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + self.time_in.enable_gradient_checkpointing() + self.vector_in.enable_gradient_checkpointing() + self.guidance_in.enable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing enabled.") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + self.time_in.disable_gradient_checkpointing() + self.vector_in.disable_gradient_checkpointing() + self.guidance_in.disable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img diff --git a/library/flux_utils.py b/library/flux_utils.py new file mode 100644 index 00000000..ba828d50 --- /dev/null +++ b/library/flux_utils.py @@ -0,0 +1,215 @@ +import json +from typing import Union +import einops +import torch + +from safetensors.torch import load_file +from accelerate import init_empty_weights +from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config + +from library import flux_models + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +MODEL_VERSION_FLUX_V1 = "flux1" + + +def load_flow_model(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.Flux: + logger.info(f"Bulding Flux model {name}") + with torch.device("meta"): + model = flux_models.Flux(flux_models.configs[name].params).to(dtype) + + # load_sft doesn't support torch.device + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_file(ckpt_path, device=str(device)) + info = model.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded Flux: {info}") + return model + + +def load_ae(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.AutoEncoder: + logger.info("Building AutoEncoder") + with torch.device("meta"): + ae = flux_models.AutoEncoder(flux_models.configs[name].ae_params).to(dtype) + + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_file(ckpt_path, device=str(device)) + info = ae.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded AE: {info}") + return ae + + +def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> CLIPTextModel: + logger.info("Building CLIP") + CLIPL_CONFIG = { + "_name_or_path": "clip-vit-large-patch14/", + "architectures": ["CLIPModel"], + "initializer_factor": 1.0, + "logit_scale_init_value": 2.6592, + "model_type": "clip", + "projection_dim": 768, + # "text_config": { + "_name_or_path": "", + "add_cross_attention": False, + "architectures": None, + "attention_dropout": 0.0, + "bad_words_ids": None, + "bos_token_id": 0, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": None, + "decoder_start_token_id": None, + "diversity_penalty": 0.0, + "do_sample": False, + "dropout": 0.0, + "early_stopping": False, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": 2, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "hidden_act": "quick_gelu", + "hidden_size": 768, + "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 3072, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "layer_norm_eps": 1e-05, + "length_penalty": 1.0, + "max_length": 20, + "max_position_embeddings": 77, + "min_length": 0, + "model_type": "clip_text_model", + "no_repeat_ngram_size": 0, + "num_attention_heads": 12, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 12, + "num_return_sequences": 1, + "output_attentions": False, + "output_hidden_states": False, + "output_scores": False, + "pad_token_id": 1, + "prefix": None, + "problem_type": None, + "projection_dim": 768, + "pruned_heads": {}, + "remove_invalid_values": False, + "repetition_penalty": 1.0, + "return_dict": True, + "return_dict_in_generate": False, + "sep_token_id": None, + "task_specific_params": None, + "temperature": 1.0, + "tie_encoder_decoder": False, + "tie_word_embeddings": True, + "tokenizer_class": None, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": None, + "torchscript": False, + "transformers_version": "4.16.0.dev0", + "use_bfloat16": False, + "vocab_size": 49408, + "hidden_act": "gelu", + "hidden_size": 1280, + "intermediate_size": 5120, + "num_attention_heads": 20, + "num_hidden_layers": 32, + # }, + # "text_config_dict": { + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "projection_dim": 768, + # }, + # "torch_dtype": "float32", + # "transformers_version": None, + } + config = CLIPConfig(**CLIPL_CONFIG) + with init_empty_weights(): + clip = CLIPTextModel._from_config(config) + + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_file(ckpt_path, device=str(device)) + info = clip.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded CLIP: {info}") + return clip + + +def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> T5EncoderModel: + T5_CONFIG_JSON = """ +{ + "architectures": [ + "T5EncoderModel" + ], + "classifier_dropout": 0.0, + "d_ff": 10240, + "d_kv": 64, + "d_model": 4096, + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 24, + "num_heads": 64, + "num_layers": 24, + "output_past": true, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.41.2", + "use_cache": true, + "vocab_size": 32128 +} +""" + config = json.loads(T5_CONFIG_JSON) + config = T5Config(**config) + with init_empty_weights(): + t5xxl = T5EncoderModel._from_config(config) + + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_file(ckpt_path, device=str(device)) + info = t5xxl.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded T5xxl: {info}") + return t5xxl + + +def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int): + img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :] + img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + return img_ids + + +def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: + """ + x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) + return x + + +def pack_latents(x: torch.Tensor) -> torch.Tensor: + """ + x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + return x diff --git a/library/sd3_models.py b/library/sd3_models.py index 28378c73..ec704dcb 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -15,6 +15,12 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from transformers import CLIPTokenizer, T5TokenizerFast +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) memory_efficient_attention = None @@ -95,7 +101,9 @@ class SDTokenizer: batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) # truncate to max_length - print(f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}") + print( + f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}" + ) if truncate_to_max_length and len(batch) > self.max_length: batch = batch[: self.max_length] if truncate_length is not None and len(batch) > truncate_length: @@ -1554,6 +1562,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.set_clip_options({"layer": layer_idx}) self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def gradient_checkpointing_enable(self): + logger.warning("Gradient checkpointing is not supported for this model") + def set_attn_mode(self, mode): raise NotImplementedError("This model does not support setting the attention mode") @@ -1925,6 +1944,7 @@ def create_clip_l(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[s return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG, ) + clip_l.gradient_checkpointing_enable() if state_dict is not None: # update state_dict if provided to include logit_scale and text_projection.weight avoid errors if "logit_scale" not in state_dict: diff --git a/library/strategy_flux.py b/library/strategy_flux.py new file mode 100644 index 00000000..f194ccf6 --- /dev/null +++ b/library/strategy_flux.py @@ -0,0 +1,244 @@ +import os +import glob +from typing import Any, List, Optional, Tuple, Union +import torch +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast + +from library import sd3_utils, train_util +from library import sd3_models +from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14" +T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl" + + +class FluxTokenizeStrategy(TokenizeStrategy): + def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None: + self.t5xxl_max_length = t5xxl_max_length + self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + + l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt") + + t5_attn_mask = t5_tokens["attention_mask"] + l_tokens = l_tokens["input_ids"] + t5_tokens = t5_tokens["input_ids"] + + return [l_tokens, t5_tokens, t5_attn_mask] + + +class FluxTextEncodingStrategy(TextEncodingStrategy): + def __init__(self) -> None: + pass + + def encode_tokens( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: List[torch.Tensor], + apply_t5_attn_mask: bool = False, + ) -> List[torch.Tensor]: + # supports single model inference only + + clip_l, t5xxl = models + l_tokens, t5_tokens = tokens[:2] + t5_attn_mask = tokens[2] if len(tokens) > 2 else None + + if clip_l is not None and l_tokens is not None: + l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"] + else: + l_pooled = None + + if t5xxl is not None and t5_tokens is not None: + # t5_out is [1, max length, 4096] + t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True) + if apply_t5_attn_mask: + t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) + txt_ids = torch.zeros(1, t5_out.shape[1], 3, device=t5_out.device) + else: + t5_out = None + txt_ids = None + + return [l_pooled, t5_out, txt_ids] + + +class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz" + + def __init__( + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + apply_t5_attn_mask: bool = False, + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + self.apply_t5_attn_mask = apply_t5_attn_mask + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, npz_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(npz_path) + if "l_pooled" not in npz: + return False + if "t5_out" not in npz: + return False + if "txt_ids" not in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray: + return t5_out * np.expand_dims(t5_attn_mask, -1) + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + l_pooled = data["l_pooled"] + t5_out = data["t5_out"] + txt_ids = data["txt_ids"] + + if self.apply_t5_attn_mask: + t5_attn_mask = data["t5_attn_mask"] + t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) + + return [l_pooled, t5_out, txt_ids] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy + captions = [info.caption for info in infos] + + tokens_and_masks = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, tokens_and_masks, self.apply_t5_attn_mask + ) + + if l_pooled.dtype == torch.bfloat16: + l_pooled = l_pooled.float() + if t5_out.dtype == torch.bfloat16: + t5_out = t5_out.float() + if txt_ids.dtype == torch.bfloat16: + txt_ids = txt_ids.float() + + l_pooled = l_pooled.cpu().numpy() + t5_out = t5_out.cpu().numpy() + txt_ids = txt_ids.cpu().numpy() + + for i, info in enumerate(infos): + l_pooled_i = l_pooled[i] + t5_out_i = t5_out[i] + txt_ids_i = txt_ids[i] + + if self.cache_to_disk: + t5_attn_mask = tokens_and_masks[2] + t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() + np.savez( + info.text_encoder_outputs_npz, + l_pooled=l_pooled_i, + t5_out=t5_out_i, + txt_ids=txt_ids_i, + t5_attn_mask=t5_attn_mask_i, + ) + else: + info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i) + + +class FluxLatentsCachingStrategy(LatentsCachingStrategy): + FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz" + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX) + if len(npz_file) == 0: + return None, None + w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) + + +if __name__ == "__main__": + # test code for FluxTokenizeStrategy + # tokenizer = sd3_models.SD3Tokenizer() + strategy = FluxTokenizeStrategy(256) + text = "hello world" + + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + # print(l_tokens.shape) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + texts = ["hello world", "the quick brown fox jumps over the lazy dog"] + l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens_2 = strategy.t5xxl( + texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + print(l_tokens_2) + print(g_tokens_2) + print(t5_tokens_2) + + # compare + print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0])) + print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0])) + print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0])) + + text = ",".join(["hello world! this is long text"] * 50) + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + print(f"model max length l: {strategy.clip_l.model_max_length}") + print(f"model max length g: {strategy.clip_g.model_max_length}") + print(f"model max length t5: {strategy.t5xxl.model_max_length}") diff --git a/networks/lora_flux.py b/networks/lora_flux.py new file mode 100644 index 00000000..141137b4 --- /dev/null +++ b/networks/lora_flux.py @@ -0,0 +1,730 @@ +# temporary minimum implementation of LoRA +# FLUX doesn't have Conv2d, so we ignore it +# TODO commonize with the original implementation + +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +from transformers import CLIPTextModel +import numpy as np +import torch +import re +from library.utils import setup_logging +from library.sdxl_original_unet import SdxlUNet2DConditionModel + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded + lx * self.multiplier * scale + + +class LoRAInfModule(LoRAModule): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) + + self.org_module_ref = [org_module] # 後から参照できるように + self.enabled = True + self.network: LoRANetwork = None + + def set_network(self, network): + self.network = network + + # freezeしてマージする + def merge_to(self, sd, dtype, device): + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"] + org_dtype = weight.dtype + org_device = weight.device + weight = weight.to(torch.float) # calc in float + + if dtype is None: + dtype = org_dtype + if device is None: + device = org_device + + # get up/down weight + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + # 復元できるマージのため、このモジュールのweightを返す + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + # get up/down weight from module + up_weight = self.lora_up.weight.to(torch.float) + down_weight = self.lora_down.weight.to(torch.float) + + # pre-calculated weight + if len(down_weight.size()) == 2: + # linear + weight = self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = self.multiplier * conved * self.scale + + return weight + + def set_region(self, region): + self.region = region + self.region_mask = None + + def default_forward(self, x): + # logger.info(f"default_forward {self.lora_name} {x.size()}") + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + return self.default_forward(x) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + ae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + flux, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoders, + flux, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + varbose=True, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping + modules_dim = {} + modules_alpha = {} + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork(text_encoders, flux, multiplier=multiplier, module_class=module_class) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_FLUX = "lora_flux" + LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" + + def __init__( + self, + text_encoders: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = LoRAModule, + varbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + if self.conv_lora_dim is not None: + logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + + # create module instances + def create_modules( + is_flux: bool, text_encoder_idx: Optional[int], root_module: torch.nn.Module, target_replace_modules: List[str] + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_FLUX + if is_flux + else (self.LORA_PREFIX_TEXT_ENCODER_CLIP if text_encoder_idx == 0 else self.LORA_PREFIX_TEXT_ENCODER_T5) + ) + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + + dim = None + alpha = None + + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + ) + loras.append(lora) + return loras, skipped + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + index = i + logger.info(f"create LoRA for Text Encoder {index+1}:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.FLUX_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + skipped = skipped_te + skipped_un + if varbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, flux, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP) or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_FLUX): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + # TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?) + # if ( + # self.loraplus_lr_ratio is not None + # or self.loraplus_text_encoder_lr_ratio is not None + # or self.loraplus_unet_lr_ratio is not None + # ): + # assert ( + # optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower() + # ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません" + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + params, descriptions = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + # if self.block_lr: + # is_sdxl = False + # for lora in self.unet_loras: + # if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name: + # is_sdxl = True + # break + + # # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 + # block_idx_to_lora = {} + # for lora in self.unet_loras: + # idx = get_block_index(lora.lora_name, is_sdxl) + # if idx not in block_idx_to_lora: + # block_idx_to_lora[idx] = [] + # block_idx_to_lora[idx].append(lora) + + # # blockごとにパラメータを設定する + # for idx, block_loras in block_idx_to_lora.items(): + # params, descriptions = assemble_params( + # block_loras, + # (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx), + # self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + # ) + # all_params.extend(params) + # lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions]) + + # else: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 67ccae62..4d6e3f18 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -52,6 +52,11 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): self.logit_scale = logit_scale self.ckpt_info = ckpt_info + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) + return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet def get_tokenize_strategy(self, args): diff --git a/train_network.py b/train_network.py index 3828fed1..48d98862 100644 --- a/train_network.py +++ b/train_network.py @@ -100,6 +100,12 @@ class NetworkTrainer: def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) + return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet def get_tokenize_strategy(self, args): @@ -147,6 +153,81 @@ class NetworkTrainer: def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet) + # region SD/SDXL + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images).latent_dist.sample() + + def shift_scale_latents(self, args, latents): + return latents * self.vae_scale_factor + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + ): + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + for x in noisy_latents: + x.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = self.call_unet( + args, + accelerator, + unet, + noisy_latents.requires_grad_(train_unet), + timesteps, + text_encoder_conds, + batch, + weight_dtype, + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + return noise_pred, target, timesteps, huber_c, None + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + return loss + + # endregion + def train(self, args): session_id = random.randint(0, 2**32) training_started_at = time.time() @@ -253,11 +334,6 @@ class NetworkTrainer: # text_encoder is List[CLIPTextModel] or CLIPTextModel text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える - vae.set_use_memory_efficient_attention_xformers(args.xformers) - # 差分追加学習のためにモデルを読み込む sys.path.append(os.path.dirname(__file__)) accelerator.print("import network module:", args.network_module) @@ -445,16 +521,19 @@ class NetworkTrainer: unet_weight_dtype = torch.float8_e4m3fn te_weight_dtype = torch.float8_e4m3fn + unet.to(accelerator.device) # this makes faster `to(dtype)` below + unet.requires_grad_(False) - unet.to(dtype=unet_weight_dtype) + unet.to(dtype=unet_weight_dtype) # this takes long time and large memory for t_enc in text_encoders: t_enc.requires_grad_(False) # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 if t_enc.device.type != "cpu": t_enc.to(dtype=te_weight_dtype) - # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + if hasattr(t_enc.text_model, "embeddings"): + # nn.Embedding not support FP8 + t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: @@ -851,12 +930,7 @@ class NetworkTrainer: global_step = 0 - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - if args.zero_terminal_snr: - custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + noise_scheduler = self.get_noise_scheduler(args, accelerator.device) if accelerator.is_main_process: init_kwargs = {} @@ -913,6 +987,13 @@ class NetworkTrainer: initial_step -= len(train_dataloader) global_step = initial_step + # log device and dtype for each model + logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") + for t_enc in text_encoders: + logger.info(f"text_encoder dtype: {te_weight_dtype}, device: {t_enc.device}") + + clean_memory_on_device(accelerator.device) + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -940,13 +1021,15 @@ class NetworkTrainer: else: with torch.no_grad(): # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) + latents = self.encode_images_to_latents(args, accelerator, vae, batch["images"].to(vae_dtype)) + latents = latents.to(dtype=weight_dtype) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") latents = torch.nan_to_num(latents, 0, out=latents) - latents = latents * self.vae_scale_factor + + latents = self.shift_scale_latents(args, latents) # get multiplier for each sample if network_has_multiplier: @@ -985,41 +1068,25 @@ class NetworkTrainer: if args.full_fp16: text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents + # sample noise, call unet, get target + noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, ) - # ensure the hidden state will require grad - if args.gradient_checkpointing: - for x in noisy_latents: - x.requires_grad_(True) - for t in text_encoder_conds: - t.requires_grad_(True) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = self.call_unet( - args, - accelerator, - unet, - noisy_latents.requires_grad_(train_unet), - timesteps, - text_encoder_conds, - batch, - weight_dtype, - ) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - loss = train_util.conditional_loss( noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) + if weighting is not None: + loss = loss * weighting if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -1027,14 +1094,8 @@ class NetworkTrainer: loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. + loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし From 808d2d1f48e2f4e544d47464edb2727c03da2f53 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 9 Aug 2024 23:02:51 +0900 Subject: [PATCH 096/356] fix typos --- flux_train_network.py | 2 +- library/flux_models.py | 4 ++-- library/flux_utils.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 7c762c86..e4be97ad 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -250,7 +250,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # ) with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( img=packed_noisy_model_input, img_ids=img_ids, diff --git a/library/flux_models.py b/library/flux_models.py index d0955e37..92c79bcc 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -685,11 +685,11 @@ class DoubleStreamBlock(nn.Module): attn = attention(q, k, v, pe=pe) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] - # calculate the img bloks + # calculate the img blocks img = img + img_mod1.gate * self.img_attn.proj(img_attn) img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) - # calculate the txt bloks + # calculate the txt blocks txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) return img, txt diff --git a/library/flux_utils.py b/library/flux_utils.py index ba828d50..166cd833 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -20,7 +20,7 @@ MODEL_VERSION_FLUX_V1 = "flux1" def load_flow_model(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.Flux: - logger.info(f"Bulding Flux model {name}") + logger.info(f"Building Flux model {name}") with torch.device("meta"): model = flux_models.Flux(flux_models.configs[name].params).to(dtype) From 358f13f2c92a04fb524006f124fc029a9edb0eaf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 10 Aug 2024 14:03:59 +0900 Subject: [PATCH 097/356] fix alpha is ignored --- networks/lora_flux.py | 41 +++++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 141137b4..332a73d9 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -307,7 +307,9 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh module_class = LoRAInfModule if for_inference else LoRAModule - network = LoRANetwork(text_encoders, flux, multiplier=multiplier, module_class=module_class) + network = LoRANetwork( + text_encoders, flux, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class + ) return network, weights_sd @@ -331,6 +333,8 @@ class LoRANetwork(torch.nn.Module): conv_lora_dim: Optional[int] = None, conv_alpha: Optional[float] = None, module_class: Type[object] = LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, varbose: Optional[bool] = False, ) -> None: super().__init__() @@ -348,12 +352,15 @@ class LoRANetwork(torch.nn.Module): self.loraplus_unet_lr_ratio = None self.loraplus_text_encoder_lr_ratio = None - logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - logger.info( - f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" - ) - if self.conv_lora_dim is not None: - logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + if self.conv_lora_dim is not None: + logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") # create module instances def create_modules( @@ -381,13 +388,19 @@ class LoRANetwork(torch.nn.Module): dim = None alpha = None - # 通常、すべて対象とする - if is_linear or is_conv2d_1x1: - dim = self.lora_dim - alpha = self.alpha - elif self.conv_lora_dim is not None: - dim = self.conv_lora_dim - alpha = self.conv_alpha + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha if dim is None or dim == 0: # skipした情報を出力 From 8a0f12dde812994ec3facdcdb7c08b362dbceb0f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 10 Aug 2024 23:42:05 +0900 Subject: [PATCH 098/356] update FLUX LoRA training --- README.md | 29 ++++++++--- flux_train_network.py | 101 +++++++++++++++++++++++++++++++------- library/sai_model_spec.py | 24 +++++++-- library/strategy_flux.py | 4 +- library/train_util.py | 9 ++-- networks/lora_flux.py | 2 +- train_network.py | 18 +++++-- 7 files changed, 148 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index a0b02f10..1089dd00 100644 --- a/README.md +++ b/README.md @@ -2,24 +2,41 @@ This repository contains training, generation and utility scripts for Stable Dif ## FLUX.1 LoRA training (WIP) -__Aug 9, 2024__: +This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. + +Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. -We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. +We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below. It will work with 24GB VRAM GPUs. ``` -accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0 --loss_type l2 ``` +LoRAs for Text Encoders are not tested yet. + +We have added some new options (Aug 10, 2024): `--time_sampling`, `--sigmoid_scale`, `--model_prediction_type` and `--discrete_flow_shift`. The options are as follows: + +- `--timestep_sampling` is the method to sample timesteps (0-1): `sigma` (sigma-based, same as SD3), `uniform` (uniform random), or `sigmoid` (sigmoid of random normal, same as x-flux). +- `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform. +- `--model_prediction_type` is how to interpret and process the model prediction: `raw` (use as is, same as x-flux), `additive` (add to noisy input), `sigma_scaled` (apply sigma scaling, same as SD3). +- `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3). + +`--loss_type` may be useful for FLUX.1 training. The default is `l2`. + +In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. Other settings may work better, so please try different settings. + +We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. + +The trained LoRA model can be used with ComfyUI. + The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. ``` -python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors +python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` -Unfortnately the training result is not good. Please let us know if you have any idea to improve the training. - ## SD3 training SD3 training is done with `sd3_train.py`. diff --git a/flux_train_network.py b/flux_train_network.py index e4be97ad..69b6e8ea 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -135,7 +135,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): pass def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) return noise_scheduler @@ -211,21 +211,32 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image - # for weighting schemes where we sample timesteps non-uniformly - u = compute_density_for_timestep_sampling( - weighting_scheme=args.weighting_scheme, - batch_size=bsz, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - mode_scale=args.mode_scale, - ) - indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": + # Simple random t-based noise sampling + if args.timestep_sampling == "sigmoid": + # https://github.com/XLabs-AI/x-flux/tree/main + t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=accelerator.device)) + else: + t = torch.rand((bsz,), device=accelerator.device) + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * latents + t * noise + else: + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) - # Add noise according to flow matching. - sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + # Add noise according to flow matching. + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents # pack latents and get img_ids packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 @@ -264,11 +275,20 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) - model_pred = model_pred * (-sigmas) + noisy_model_input + if args.model_prediction_type == "raw": + # use model_pred as is + weighting = None + elif args.model_prediction_type == "additive": + # add the model_pred to the noisy_model_input + model_pred = model_pred + noisy_model_input + weighting = None + elif args.model_prediction_type == "sigma_scaled": + # apply sigma scaling + model_pred = model_pred * (-sigmas) + noisy_model_input - # these weighting schemes use a uniform timestep sampling - # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss: this is different from SD3 target = noise - latents @@ -278,6 +298,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") + + def update_metadata(self, metadata, args): + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + metadata["ss_guidance_scale"] = args.guidance_scale + metadata["ss_timestep_sampling"] = args.timestep_sampling + metadata["ss_sigmoid_scale"] = args.sigmoid_scale + metadata["ss_model_prediction_type"] = args.model_prediction_type + metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() @@ -318,6 +353,34 @@ def setup_parser() -> argparse.ArgumentParser: default=3.5, help="the FLUX.1 dev variant is a guidance distilled model", ) + + parser.add_argument( + "--timestep_sampling", + choices=["sigma", "uniform", "sigmoid"], + default="sigma", + help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。", + ) + parser.add_argument( + "--sigmoid_scale", + type=float, + default=1.0, + help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', + ) + parser.add_argument( + "--model_prediction_type", + choices=["raw", "additive", "sigma_scaled"], + default="sigma_scaled", + help="How to interpret and process the model prediction: " + "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." + " / モデル予測の解釈と処理方法:" + "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=3.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", + ) return parser diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index af073677..ad72ec00 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -59,6 +59,8 @@ ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v" ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" ARCH_SD3_M = "stable-diffusion-3-medium" ARCH_SD3_UNKNOWN = "stable-diffusion-3" +ARCH_FLUX_1_DEV = "flux-1-dev" +ARCH_FLUX_1_UNKNOWN = "flux-1" ADAPTER_LORA = "lora" ADAPTER_TEXTUAL_INVERSION = "textual-inversion" @@ -66,6 +68,7 @@ ADAPTER_TEXTUAL_INVERSION = "textual-inversion" IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI" IMPL_DIFFUSERS = "diffusers" +IMPL_FLUX = "https://github.com/black-forest-labs/flux" PRED_TYPE_EPSILON = "epsilon" PRED_TYPE_V = "v" @@ -118,10 +121,11 @@ def build_metadata( merged_from: Optional[str] = None, timesteps: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, - sd3: str = None, + sd3: Optional[str] = None, + flux: Optional[str] = None, ): """ - sd3: only supports "m" + sd3: only supports "m", flux: only supports "dev" """ # if state_dict is None, hash is not calculated @@ -140,6 +144,11 @@ def build_metadata( arch = ARCH_SD3_M else: arch = ARCH_SD3_UNKNOWN + elif flux is not None: + if flux == "dev": + arch = ARCH_FLUX_1_DEV + else: + arch = ARCH_FLUX_1_UNKNOWN elif v2: if v_parameterization: arch = ARCH_SD_V2_768_V @@ -158,7 +167,10 @@ def build_metadata( if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion - if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: + if flux is not None: + # Flux + impl = IMPL_FLUX + elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: # Stable Diffusion ckpt, TI, SDXL LoRA impl = IMPL_STABILITY_AI else: @@ -216,7 +228,7 @@ def build_metadata( reso = (reso[0], reso[0]) else: # resolution is defined in dataset, so use default - if sdxl or sd3 is not None: + if sdxl or sd3 is not None or flux is not None: reso = 1024 elif v2 and v_parameterization: reso = 768 @@ -227,7 +239,9 @@ def build_metadata( metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}" - if v_parameterization: + if flux is not None: + del metadata["modelspec.prediction_type"] + elif v_parameterization: metadata["modelspec.prediction_type"] = PRED_TYPE_V else: metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON diff --git a/library/strategy_flux.py b/library/strategy_flux.py index f194ccf6..13459d32 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -63,11 +63,11 @@ class FluxTextEncodingStrategy(TextEncodingStrategy): l_pooled = None if t5xxl is not None and t5_tokens is not None: - # t5_out is [1, max length, 4096] + # t5_out is [b, max length, 4096] t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True) if apply_t5_attn_mask: t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) - txt_ids = torch.zeros(1, t5_out.shape[1], 3, device=t5_out.device) + txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device) else: t5_out = None txt_ids = None diff --git a/library/train_util.py b/library/train_util.py index fc458a88..6b74bb3f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3186,6 +3186,7 @@ def get_sai_model_spec( textual_inversion: bool, is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA sd3: str = None, + flux: str = None, ): timestamp = time.time() @@ -3220,6 +3221,7 @@ def get_sai_model_spec( timesteps=timesteps, clip_skip=args.clip_skip, # None or int sd3=sd3, + flux=flux, ) return metadata @@ -3642,8 +3644,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--loss_type", type=str, default="l2", - choices=["l2", "huber", "smooth_l1"], - help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L2、Huber、またはsmooth L1)、デフォルトはL2", + choices=["l1", "l2", "huber", "smooth_l1"], + help="The type of loss function to use (L1, L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L1、L2、Huber、またはsmooth L1)、デフォルトはL2", ) parser.add_argument( "--huber_schedule", @@ -5359,9 +5361,10 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): def conditional_loss( model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1 ): - if loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) + elif loss_type == "l1": + loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) elif loss_type == "huber": loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 332a73d9..a4dab287 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -316,7 +316,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh class LoRANetwork(torch.nn.Module): FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] - LORA_PREFIX_FLUX = "lora_flux" + LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" diff --git a/train_network.py b/train_network.py index 48d98862..367203f5 100644 --- a/train_network.py +++ b/train_network.py @@ -226,6 +226,12 @@ class NetworkTrainer: loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) return loss + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False) + + def update_metadata(self, metadata, args): + pass + # endregion def train(self, args): @@ -521,10 +527,13 @@ class NetworkTrainer: unet_weight_dtype = torch.float8_e4m3fn te_weight_dtype = torch.float8_e4m3fn - unet.to(accelerator.device) # this makes faster `to(dtype)` below + # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM + # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory + + unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above unet.requires_grad_(False) - unet.to(dtype=unet_weight_dtype) # this takes long time and large memory + unet.to(dtype=unet_weight_dtype) for t_enc in text_encoders: t_enc.requires_grad_(False) @@ -718,8 +727,11 @@ class NetworkTrainer: "ss_loss_type": args.loss_type, "ss_huber_schedule": args.huber_schedule, "ss_huber_c": args.huber_c, + "ss_fp8_base": args.fp8_base, } + self.update_metadata(metadata, args) # architecture specific metadata + if use_user_config: # save metadata of multiple datasets # NOTE: pack "ss_datasets" value as json one time @@ -964,7 +976,7 @@ class NetworkTrainer: metadata["ss_epoch"] = str(epoch_no) metadata_to_save = minimum_metadata if args.no_metadata else metadata - sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False) + sai_metadata = self.get_sai_model_spec(args) metadata_to_save.update(sai_metadata) unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save) From 82314ac2e7926ed15eac6306bebe4ffb78280346 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 11 Aug 2024 11:14:08 +0900 Subject: [PATCH 099/356] update readme for ai toolkit settings --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 1089dd00..d016bcec 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,11 @@ We have added some new options (Aug 10, 2024): `--time_sampling`, `--sigmoid_sca `--loss_type` may be useful for FLUX.1 training. The default is `l2`. -In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. Other settings may work better, so please try different settings. +In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. + +additional note (Aug 11): A quick check shows that the settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). + +Other settings may work better, so please try different settings. We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. From d25ae361d06bb6f49c104ca2e6b4a9188a88c95f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 11 Aug 2024 19:07:07 +0900 Subject: [PATCH 100/356] fix apply_t5_attn_mask to work --- README.md | 2 ++ flux_train_network.py | 6 ++++-- library/strategy_flux.py | 18 +++++++++++++----- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index d016bcec..d47776ca 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ This repository contains training, generation and utility scripts for Stable Dif This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. +Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-generate the latents cache file if you have used the option before. + Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. diff --git a/flux_train_network.py b/flux_train_network.py index 69b6e8ea..59a666aa 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -67,14 +67,16 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return latents_caching_strategy def get_text_encoding_strategy(self, args): - return strategy_flux.FluxTextEncodingStrategy() + return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) def get_models_for_text_encoding(self, args, accelerator, text_encoders): return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])] def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: - return strategy_flux.FluxTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False) + return strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False, apply_t5_attn_mask=args.apply_t5_attn_mask + ) else: return None diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 13459d32..3880a1e1 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -41,17 +41,24 @@ class FluxTokenizeStrategy(TokenizeStrategy): class FluxTextEncodingStrategy(TextEncodingStrategy): - def __init__(self) -> None: - pass + def __init__(self, apply_t5_attn_mask: Optional[bool] = None) -> None: + """ + Args: + apply_t5_attn_mask: Default value for apply_t5_attn_mask. + """ + self.apply_t5_attn_mask = apply_t5_attn_mask def encode_tokens( self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], - apply_t5_attn_mask: bool = False, + apply_t5_attn_mask: Optional[bool] = None, ) -> List[torch.Tensor]: - # supports single model inference only + # supports single model inference + + if apply_t5_attn_mask is None: + apply_t5_attn_mask = self.apply_t5_attn_mask clip_l, t5xxl = models l_tokens, t5_tokens = tokens[:2] @@ -137,8 +144,9 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): + # attn_mask is not applied when caching to disk: it is applied when loading from disk l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens_and_masks, self.apply_t5_attn_mask + tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk ) if l_pooled.dtype == torch.bfloat16: From 74f91c2ff71035db105b218128567e6b8fa6c80d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 11 Aug 2024 21:54:10 +0900 Subject: [PATCH 101/356] correct option name closes #1446 --- docs/train_README-ja.md | 2 +- docs/train_README-zh.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/train_README-ja.md b/docs/train_README-ja.md index d186bf24..cfa5a7d1 100644 --- a/docs/train_README-ja.md +++ b/docs/train_README-ja.md @@ -648,7 +648,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b 詳細については各自お調べください。 - 任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--scheduler_args`でオプション引数を指定してください。 + 任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--lr_scheduler_args`でオプション引数を指定してください。 ### オプティマイザの指定について diff --git a/docs/train_README-zh.md b/docs/train_README-zh.md index 7e00278c..1bc47e0f 100644 --- a/docs/train_README-zh.md +++ b/docs/train_README-zh.md @@ -582,7 +582,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b 有关详细信息,请自行研究。 - 要使用任何调度程序,请像使用任何优化器一样使用“--scheduler_args”指定可选参数。 + 要使用任何调度程序,请像使用任何优化器一样使用“--lr_scheduler_args”指定可选参数。 ### 关于指定优化器 使用 --optimizer_args 选项指定优化器选项参数。可以以key=value的格式指定多个值。此外,您可以指定多个值,以逗号分隔。例如,要指定 AdamW 优化器的参数,``--optimizer_args weight_decay=0.01 betas=.9,.999``。 From 9e09a69df1ea8aa76ec98df3b2eed961c66432e4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 12 Aug 2024 08:19:45 +0900 Subject: [PATCH 102/356] update README --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index d47776ca..ccc83e6e 100644 --- a/README.md +++ b/README.md @@ -10,10 +10,10 @@ Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to mak Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. -We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below. It will work with 24GB VRAM GPUs. +We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. ``` -accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0 --loss_type l2 +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 --loss_type l2 ``` LoRAs for Text Encoders are not tested yet. @@ -29,7 +29,7 @@ We have added some new options (Aug 10, 2024): `--time_sampling`, `--sigmoid_sca In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. -additional note (Aug 11): A quick check shows that the settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). +additional note (Aug 11): A quick check shows that the settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). This seems to be a good starting point. Thanks to Ostris for the great work! Other settings may work better, so please try different settings. From 4af36f96320d553025cfdf067cae1e346af44a67 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Mon, 12 Aug 2024 13:24:10 +0900 Subject: [PATCH 103/356] update to work interactive mode --- README.md | 2 ++ flux_minimal_inference.py | 33 +++++++++++++++++++++++++++------ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index ccc83e6e..c0d50a5a 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,8 @@ The trained LoRA model can be used with ComfyUI. The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. +Aug 12: `--interactive` option is now working. + ``` python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index f3affca8..b09f6380 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -5,7 +5,7 @@ import datetime import math import os import random -from typing import Callable, Optional, Tuple +from typing import Callable, List, Optional, Tuple import einops import numpy as np @@ -121,6 +121,9 @@ def generate_image( steps: Optional[int], guidance: float, ): + seed = seed if seed is not None else random.randint(0, 2**32 - 1) + logger.info(f"Seed: {seed}") + # make first noise with packed shape # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2 packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16) @@ -183,9 +186,7 @@ def generate_image( steps = 4 if is_schnell else 50 img_ids = img_ids.to(device) - x = do_sample( - accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance_scale, is_schnell, device, flux_dtype - ) + x = do_sample(accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, is_schnell, device, flux_dtype) if args.offload: model = model.cpu() # del model @@ -255,6 +256,7 @@ if __name__ == "__main__": default=[], help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)", ) + parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") parser.add_argument("--width", type=int, default=target_width) parser.add_argument("--height", type=int, default=target_height) parser.add_argument("--interactive", action="store_true") @@ -341,6 +343,7 @@ if __name__ == "__main__": ae = accelerator.prepare(ae) # LoRA + lora_models: List[lora_flux.LoRANetwork] = [] for weights_file in args.lora_weights: if ";" in weights_file: weights_file, multiplier = weights_file.split(";") @@ -351,7 +354,16 @@ if __name__ == "__main__": lora_model, weights_sd = lora_flux.create_network_from_weights( multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True ) - lora_model.merge_to([clip_l, t5xxl], model, weights_sd) + if args.merge_lora_weights: + lora_model.merge_to([clip_l, t5xxl], model, weights_sd) + else: + lora_model.apply_to([clip_l, t5xxl], model) + info = lora_model.load_state_dict(weights_sd, strict=True) + logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + lora_model.eval() + lora_model.to(device) + + lora_models.append(lora_model) if not args.interactive: generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance) @@ -363,7 +375,9 @@ if __name__ == "__main__": guidance = args.guidance while True: - print("Enter prompt (empty to exit). Options: --w --h --s --d --g ") + print( + "Enter prompt (empty to exit). Options: --w --h --s --d --g --m " + ) prompt = input() if prompt == "": break @@ -384,6 +398,13 @@ if __name__ == "__main__": seed = int(opt[1:].strip()) elif opt.startswith("g"): guidance = float(opt[1:].strip()) + elif opt.startswith("m"): + mutipliers = opt[1:].strip().split(",") + if len(mutipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(mutipliers[i])) generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance) From a7d5dabde3facb57d069eba0aa91e961e04303ad Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 12 Aug 2024 17:09:19 +0900 Subject: [PATCH 104/356] Update readme --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index c0d50a5a..19aed221 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,12 @@ We have added a new training script for LoRA training. The script is `flux_train accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 --loss_type l2 ``` +The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below: + +``` +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"` +``` + LoRAs for Text Encoders are not tested yet. We have added some new options (Aug 10, 2024): `--time_sampling`, `--sigmoid_scale`, `--model_prediction_type` and `--discrete_flow_shift`. The options are as follows: From 0415d200f5f3db89e33b33c9b36cb3c3e15d0266 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Tue, 13 Aug 2024 21:00:16 +0900 Subject: [PATCH 105/356] update dependencies closes #1450 --- requirements.txt | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index e99775b8..4ee19b3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,12 @@ -accelerate==0.25.0 -transformers==4.36.2 +accelerate==0.33.0 +transformers==4.44.0 diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 opencv-python==4.7.0.68 einops==0.7.0 pytorch-lightning==1.9.0 -bitsandbytes==0.43.0 +bitsandbytes==0.43.3 prodigyopt==1.0 lion-pytorch==0.0.6 tensorboard @@ -16,7 +16,7 @@ altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 -huggingface-hub==0.20.1 +huggingface-hub==0.24.5 # for Image utils imagesize==1.4.1 # for BLIP captioning @@ -38,5 +38,7 @@ imagesize==1.4.1 # open-clip-torch==2.20.0 # For logging rich==13.7.0 +# for T5XXL tokenizer (SD3/FLUX) +sentencepiece==0.2.0 # for kohya_ss library -e . From 9711c96f96038df5fa1a15d073244198b93ef0a2 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Tue, 13 Aug 2024 21:03:17 +0900 Subject: [PATCH 106/356] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 19aed221..3eb034ed 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-ge Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. -Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. +__Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. From 56d7651f0895c805c403a8db01083a522503eb7d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 13 Aug 2024 22:28:39 +0900 Subject: [PATCH 107/356] add experimental split mode for FLUX --- README.md | 22 +++++- flux_train_network.py | 110 +++++++++++++++++++++++---- library/flux_models.py | 165 +++++++++++++++++++++++++++++++++++++++++ networks/lora_flux.py | 30 ++++++-- 4 files changed, 304 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 3eb034ed..64b01880 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,22 @@ This repository contains training, generation and utility scripts for Stable Dif This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. +__Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ + +Aug 13, 2024: + +__Experimental__ A network argument `train_blocks` is added to `lora_flux`. This is to select the target blocks of LoRA from FLUX double blocks and single blocks. Specify like `--network_args "train_blocks=single"`. `all` trains both double blocks and single blocks, `double` trains only double blocks, and `single` trains only single blocks. The default (omission) is `all`. + +This argument is available even if `--split_mode` is not specified. + +__Experimental__ `--split_mode` option is added to `flux_train_network.py`. This splits FLUX into double blocks and single blocks for training. By enabling gradients only for the single blocks part, memory usage is reduced. When this option is specified, you need to specify `"train_blocks=single"` in the network arguments. + +This option enables training with 12GB VRAM GPUs, but the training speed is 2-3 times slower than the default. + Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-generate the latents cache file if you have used the option before. Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. -__Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ - We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. ``` @@ -19,7 +29,13 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below: ``` ---optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"` +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" +``` + +The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_mode` and `train_blocks=single` options. Please use settings like below: + +``` +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" ``` LoRAs for Text Encoders are not tested yet. diff --git a/flux_train_network.py b/flux_train_network.py index 59a666aa..1d1f00d8 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -37,10 +37,16 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): args.network_train_unet_only or not args.cache_text_encoder_outputs ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" - train_dataset_group.verify_bucket_reso_steps(32) + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models + name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # TODO change this to a more robust way + # if we load to cpu, flux.to(fp8) takes a long time + model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") + + if args.split_mode: + model = self.prepare_split_model(model, weight_dtype, accelerator) clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu") clip_l.eval() @@ -49,13 +55,47 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu") t5xxl.eval() - name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # TODO change this to a more robust way - # if we load to cpu, flux.to(fp8) takes a long time - model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + def prepare_split_model(self, model, weight_dtype, accelerator): + from accelerate import init_empty_weights + + logger.info("prepare split model") + with init_empty_weights(): + flux_upper = flux_models.FluxUpper(model.params) + flux_lower = flux_models.FluxLower(model.params) + sd = model.state_dict() + + # lower (trainable) + logger.info("load state dict for lower") + flux_lower.load_state_dict(sd, strict=False, assign=True) + flux_lower.to(dtype=weight_dtype) + + # upper (frozen) + logger.info("load state dict for upper") + flux_upper.load_state_dict(sd, strict=False, assign=True) + + logger.info("prepare upper model") + target_dtype = torch.float8_e4m3fn if args.fp8_base else weight_dtype + flux_upper.to(accelerator.device, dtype=target_dtype) + flux_upper.eval() + + if args.fp8_base: + # this is required to run on fp8 + flux_upper = accelerator.prepare(flux_upper) + + flux_upper.to("cpu") + + self.flux_upper = flux_upper + del model # we don't need model anymore + clean_memory_on_device(accelerator.device) + + logger.info("split model prepared") + + return flux_lower + def get_tokenize_strategy(self, args): return strategy_flux.FluxTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) @@ -262,17 +302,51 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # f"model_input: {noisy_model_input.shape}, img_ids: {img_ids.shape}, t5_out: {t5_out.shape}, txt_ids: {txt_ids.shape}, l_pooled: {l_pooled.shape}, timesteps: {timesteps.shape}, guidance_vec: {guidance_vec.shape}" # ) - with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - model_pred = unet( - img=packed_noisy_model_input, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - ) + if not args.split_mode: + # normal forward + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + ) + else: + # split forward to reduce memory usage + assert network.train_blocks == "single", "train_blocks must be single for split mode" + with accelerator.autocast(): + # move flux lower to cpu, and then move flux upper to gpu + unet.to("cpu") + clean_memory_on_device(accelerator.device) + self.flux_upper.to(accelerator.device) + + # upper model does not require grad + with torch.no_grad(): + intermediate_img, intermediate_txt, vec, pe = self.flux_upper( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + ) + + # move flux upper back to cpu, and then move flux lower to gpu + self.flux_upper.to("cpu") + clean_memory_on_device(accelerator.device) + unet.to(accelerator.device) + + # lower model requires grad + intermediate_img.requires_grad_(True) + intermediate_txt.requires_grad_(True) + vec.requires_grad_(True) + pe.requires_grad_(True) + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe) # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) @@ -331,6 +405,12 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", ) + parser.add_argument( + "--split_mode", + action="store_true", + help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", + ) # copy from Diffusers parser.add_argument( diff --git a/library/flux_models.py b/library/flux_models.py index 92c79bcc..3c7766b8 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -918,3 +918,168 @@ class Flux(nn.Module): img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img + + +class FluxUpper(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.gradient_checkpointing = False + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + self.time_in.enable_gradient_checkpointing() + self.vector_in.enable_gradient_checkpointing() + self.guidance_in.enable_gradient_checkpointing() + + for block in self.double_blocks: + block.enable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing enabled.") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + self.time_in.disable_gradient_checkpointing() + self.vector_in.disable_gradient_checkpointing() + self.guidance_in.disable_gradient_checkpointing() + + for block in self.double_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + return img, txt, vec, pe + + +class FluxLower(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams): + super().__init__() + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.out_channels = params.in_channels + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + self.gradient_checkpointing = False + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + for block in self.single_blocks: + block.enable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing enabled.") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + for block in self.single_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def forward( + self, + img: Tensor, + txt: Tensor, + vec: Tensor | None = None, + pe: Tensor | None = None, + ) -> Tensor: + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img diff --git a/networks/lora_flux.py b/networks/lora_flux.py index a4dab287..4da33542 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -252,6 +252,11 @@ def create_network( if module_dropout is not None: module_dropout = float(module_dropout) + # single or double blocks + train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double" + if train_blocks is not None: + assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}" + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoders, @@ -264,6 +269,7 @@ def create_network( module_dropout=module_dropout, conv_lora_dim=conv_dim, conv_alpha=conv_alpha, + train_blocks=train_blocks, varbose=True, ) @@ -314,9 +320,11 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh class LoRANetwork(torch.nn.Module): - FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] + # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] - LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible + LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" @@ -335,6 +343,7 @@ class LoRANetwork(torch.nn.Module): module_class: Type[object] = LoRAModule, modules_dim: Optional[Dict[str, int]] = None, modules_alpha: Optional[Dict[str, int]] = None, + train_blocks: Optional[str] = None, varbose: Optional[bool] = False, ) -> None: super().__init__() @@ -347,6 +356,7 @@ class LoRANetwork(torch.nn.Module): self.dropout = dropout self.rank_dropout = rank_dropout self.module_dropout = module_dropout + self.train_blocks = train_blocks if train_blocks is not None else "all" self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -360,7 +370,9 @@ class LoRANetwork(torch.nn.Module): f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" ) if self.conv_lora_dim is not None: - logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + logger.info( + f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + ) # create module instances def create_modules( @@ -434,9 +446,17 @@ class LoRANetwork(torch.nn.Module): skipped_te += skipped logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + # create LoRA for U-Net + if self.train_blocks == "all": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE + elif self.train_blocks == "single": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE + elif self.train_blocks == "double": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] - self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.FLUX_TARGET_REPLACE_MODULE) - logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules) + logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: From 9760d097b0bd7efbeb065d4320b2216a94e76efd Mon Sep 17 00:00:00 2001 From: DukeG Date: Wed, 14 Aug 2024 19:58:54 +0800 Subject: [PATCH 108/356] Fix AttributeError: 'T5EncoderModel' object has no attribute 'text_model' While loading T5 model in GPU. --- train_network.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index 367203f5..405aa747 100644 --- a/train_network.py +++ b/train_network.py @@ -540,9 +540,13 @@ class NetworkTrainer: # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 if t_enc.device.type != "cpu": t_enc.to(dtype=te_weight_dtype) - if hasattr(t_enc.text_model, "embeddings"): + if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"): # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + t_enc.text_model.embeddings.to( + dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"): + t_enc.encoder.embeddings.to( + dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: From 7db422211907df3c50703b419655202276a53301 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 14 Aug 2024 22:15:26 +0900 Subject: [PATCH 109/356] add sample image generation during training --- README.md | 2 + flux_train_network.py | 67 +++++++- library/flux_train_utils.py | 297 ++++++++++++++++++++++++++++++++++++ train_network.py | 13 +- 4 files changed, 374 insertions(+), 5 deletions(-) create mode 100644 library/flux_train_utils.py diff --git a/README.md b/README.md index 64b01880..7dc954fb 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,8 @@ This feature is experimental. The options and the training script may change in __Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ +Aug 14, 2024: Sample image generation during training is now supported. Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. It will be very slow when `--split_mode` is specified. + Aug 13, 2024: __Experimental__ A network argument `train_blocks` is added to `lora_flux`. This is to select the target blocks of LoRA from FLUX double blocks and single blocks. Specify like `--network_args "train_blocks=single"`. `all` trains both double blocks and single blocks, `double` trains only double blocks, and `single` trains only single blocks. The default (omission) is `all`. diff --git a/flux_train_network.py b/flux_train_network.py index 1d1f00d8..b8ea5622 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -10,7 +10,7 @@ from library.device_utils import init_ipex, clean_memory_on_device init_ipex() -from library import flux_models, flux_utils, sd3_train_utils, sd3_utils, sdxl_model_util, sdxl_train_util, strategy_flux, train_util +from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util import train_network from library.utils import setup_logging @@ -28,6 +28,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + if args.cache_text_encoder_outputs: assert ( train_dataset_group.is_text_encoder_output_cacheable() @@ -139,8 +145,31 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): text_encoders[1].to(accelerator.device, dtype=weight_dtype) with accelerator.autocast(): dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) + + # cache sample prompts + self.sample_prompts_te_outputs = None + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = sd3_train_utils.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + accelerator.wait_for_everyone() + # move back to cpu logger.info("move text encoders back to cpu") text_encoders[0].to("cpu") # , dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU text_encoders[1].to("cpu") # , dtype=torch.float32) @@ -172,9 +201,36 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) # return noise_pred - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): - # logger.warning("Sampling images is not supported for Flux model") - pass + def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): + if not args.split_mode: + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, flux, ae, text_encoder, self.sample_prompts_te_outputs + ) + return + + class FluxUpperLowerWrapper(torch.nn.Module): + def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device): + super().__init__() + self.flux_upper = flux_upper + self.flux_lower = flux_lower + self.target_device = device + + def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None): + self.flux_lower.to("cpu") + clean_memory_on_device(self.target_device) + self.flux_upper.to(self.target_device) + img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance) + self.flux_upper.to("cpu") + clean_memory_on_device(self.target_device) + self.flux_lower.to(self.target_device) + return self.flux_lower(img, txt, vec, pe) + + wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) + clean_memory_on_device(accelerator.device) + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, wrapper, ae, text_encoder, self.sample_prompts_te_outputs + ) + clean_memory_on_device(accelerator.device) def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) @@ -389,6 +445,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): metadata["ss_model_prediction_type"] = args.model_prediction_type metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py new file mode 100644 index 00000000..91f52238 --- /dev/null +++ b/library/flux_train_utils.py @@ -0,0 +1,297 @@ +import argparse +import math +import os +import numpy as np +import toml +import json +import time +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from accelerate import Accelerator, PartialState +from transformers import CLIPTextModel +from tqdm import tqdm +from PIL import Image + +from library import flux_models, flux_utils, strategy_base +from library.sd3_train_utils import load_prompts +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + flux, + ae, + text_encoders, + sample_prompts_te_outputs, + prompt_replacement=None, +): + if steps == 0: + if not args.sample_at_first: + return + else: + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return + + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts): + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return + + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap unet and text_encoder(s) + flux = accelerator.unwrap_model(flux) + text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = load_prompts(args.sample_prompts) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(): + for prompt_dict in prompts: + sample_image_inference( + accelerator, + args, + flux, + text_encoders, + ae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, + args, + flux, + text_encoders, + ae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + clean_memory_on_device(accelerator.device) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + flux: flux_models.Flux, + text_encoders: List[CLIPTextModel], + ae: flux_models.AutoEncoder, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, +): + assert isinstance(prompt_dict, dict) + # negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 20) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 3.5) + seed = prompt_dict.get("seed") + # controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + # if negative_prompt is not None: + # negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + # if negative_prompt is None: + # negative_prompt = "" + + height = max(64, height - height % 16) # round to divisible by 16 + width = max(64, width - width % 16) # round to divisible by 16 + logger.info(f"prompt: {prompt}") + # logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + # logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: + te_outputs = sample_prompts_te_outputs[prompt] + else: + tokens_and_masks = tokenize_strategy.tokenize(prompt) + te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + + l_pooled, t5_out, txt_ids = te_outputs + + # sample image + weight_dtype = ae.dtype # TOFO give dtype as argument + packed_latent_height = height // 16 + packed_latent_width = width // 16 + noise = torch.randn( + 1, + packed_latent_height * packed_latent_width, + 16 * 2 * 2, + device=accelerator.device, + dtype=weight_dtype, + generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None, + ) + timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True + img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) + + with accelerator.autocast(), torch.no_grad(): + x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale) + + x = x.float() + x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) + + # latent to image + clean_memory_on_device(accelerator.device) + org_vae_device = ae.device # will be on cpu + ae.to(accelerator.device) # distributed_state.device is same as accelerator.device + with accelerator.autocast(), torch.no_grad(): + x = ae.decode(x) + ae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + + x = x.clamp(-1, 1) + x = x.permute(0, 2, 3, 1) + image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # wandb有効時のみログを送信 + try: + wandb_tracker = accelerator.get_tracker("wandb") + try: + import wandb + except ImportError: # 事前に一度確認するのでここはエラー出ないはず + raise ImportError("No wandb / wandb がインストールされていないようです") + + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + except: # wandb 無効時 + pass + + +def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: flux_models.Flux, + img: torch.Tensor, + img_ids: torch.Tensor, + txt: torch.Tensor, + txt_ids: torch.Tensor, + vec: torch.Tensor, + timesteps: list[float], + guidance: float = 4.0, +): + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec) + + img = img + (t_prev - t_curr) * pred + + return img diff --git a/train_network.py b/train_network.py index 367203f5..53d71b57 100644 --- a/train_network.py +++ b/train_network.py @@ -232,6 +232,9 @@ class NetworkTrainer: def update_metadata(self, metadata, args): pass + def is_text_encoder_not_needed_for_training(self, args): + return False # use for sample images + # endregion def train(self, args): @@ -529,7 +532,7 @@ class NetworkTrainer: # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory - + unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above unet.requires_grad_(False) @@ -989,6 +992,14 @@ class NetworkTrainer: accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) + # if text_encoder is not needed for training, delete it to save memory. + # TODO this can be automated after SDXL sample prompt cache is implemented + if self.is_text_encoder_not_needed_for_training(args): + logger.info("text_encoder is not needed for training. deleting to save memory.") + for t_enc in text_encoders: + del t_enc + text_encoders = [] + # For --sample_at_first self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) From 8aaa1967bd3d3a9b4b44e97e5432d23f2101cf51 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Aug 2024 22:07:23 +0900 Subject: [PATCH 110/356] fix encoding latents closes #1456 --- flux_train_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index b8ea5622..daa65c85 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -238,8 +238,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return noise_scheduler def encode_images_to_latents(self, args, accelerator, vae, images): - return vae.encode(images).latent_dist.sample() - + return vae.encode(images) + def shift_scale_latents(self, args, latents): return latents From 35b6cb0cd1b319d5f34b44a8c24c81c42895fa2e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Aug 2024 22:07:35 +0900 Subject: [PATCH 111/356] update for torchvision --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 7dc954fb..bdb6bf2e 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,10 @@ This repository contains training, generation and utility scripts for Stable Dif This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. -__Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ +__Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchvision==0.19.0` with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ + +The command to install PyTorch is as follows: +`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` Aug 14, 2024: Sample image generation during training is now supported. Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. It will be very slow when `--split_mode` is specified. From 08ef886bfeb058aa6d6f7e0a19589c0fd80b3757 Mon Sep 17 00:00:00 2001 From: DukeG Date: Fri, 16 Aug 2024 11:00:08 +0800 Subject: [PATCH 112/356] Fix AttributeError: 'FluxNetworkTrainer' object has no attribute 'sample_prompts_te_outputs' Move "self.sample_prompts_te_outputs = None" from Line 150 to Line 26. --- flux_train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux_train_network.py b/flux_train_network.py index daa65c85..59b9d84b 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -23,6 +23,7 @@ logger = logging.getLogger(__name__) class FluxNetworkTrainer(train_network.NetworkTrainer): def __init__(self): super().__init__() + self.sample_prompts_te_outputs = None def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) @@ -147,7 +148,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) # cache sample prompts - self.sample_prompts_te_outputs = None if args.sample_prompts is not None: logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") From 3921a4efda1cd1d7d873177ea7f51b77c3f15d3d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 16 Aug 2024 17:06:05 +0900 Subject: [PATCH 113/356] add t5xxl max token length, support schnell --- README.md | 8 ++++++++ flux_train_network.py | 32 ++++++++++++++++++++++++++++---- library/flux_models.py | 12 ++++++++---- 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index bdb6bf2e..6fb050df 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,14 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 16, 2024: + +FLUX.1 schnell model based training is now supported (but not tested). If the name of the model file contains `schnell`, the model is treated as a schnell model. + +Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. The default is 512 in dev and 256 in schnell. + +Previously, when `--max_token_length` was specified, that value was used, and 512 was used when omitted (default). Therefore, there is no impact if `--max_token_length` was not specified. If `--max_token_length` was specified, please specify `--t5xxl_max_token_length` instead. `--max_token_length` is ignored during FLUX.1 training. + Aug 14, 2024: Sample image generation during training is now supported. Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. It will be very slow when `--split_mode` is specified. Aug 13, 2024: diff --git a/flux_train_network.py b/flux_train_network.py index 59b9d84b..b9a29c16 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -44,11 +44,18 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): args.network_train_unet_only or not args.cache_text_encoder_outputs ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + if args.max_token_length is not None: + logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + def get_flux_model_name(self, args): + return "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" + def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models - name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # TODO change this to a more robust way + name = self.get_flux_model_name(args) + # if we load to cpu, flux.to(fp8) takes a long time model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") @@ -104,7 +111,18 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return flux_lower def get_tokenize_strategy(self, args): - return strategy_flux.FluxTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + name = self.get_flux_model_name(args) + + if args.t5xxl_max_token_length is None: + if name == "schnell": + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") + return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy): return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl] @@ -239,7 +257,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def encode_images_to_latents(self, args, accelerator, vae, images): return vae.encode(images) - + def shift_scale_latents(self, args, latents): return latents @@ -470,7 +488,13 @@ def setup_parser() -> argparse.ArgumentParser: help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", ) - + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=None, + help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev" + " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512", + ) # copy from Diffusers parser.add_argument( "--weighting_scheme", diff --git a/library/flux_models.py b/library/flux_models.py index 3c7766b8..ed0bc8c7 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -863,7 +863,8 @@ class Flux(nn.Module): self.time_in.enable_gradient_checkpointing() self.vector_in.enable_gradient_checkpointing() - self.guidance_in.enable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.enable_gradient_checkpointing() for block in self.double_blocks + self.single_blocks: block.enable_gradient_checkpointing() @@ -875,7 +876,8 @@ class Flux(nn.Module): self.time_in.disable_gradient_checkpointing() self.vector_in.disable_gradient_checkpointing() - self.guidance_in.disable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.disable_gradient_checkpointing() for block in self.double_blocks + self.single_blocks: block.disable_gradient_checkpointing() @@ -972,7 +974,8 @@ class FluxUpper(nn.Module): self.time_in.enable_gradient_checkpointing() self.vector_in.enable_gradient_checkpointing() - self.guidance_in.enable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.enable_gradient_checkpointing() for block in self.double_blocks: block.enable_gradient_checkpointing() @@ -984,7 +987,8 @@ class FluxUpper(nn.Module): self.time_in.disable_gradient_checkpointing() self.vector_in.disable_gradient_checkpointing() - self.guidance_in.disable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.disable_gradient_checkpointing() for block in self.double_blocks: block.disable_gradient_checkpointing() From e45d3f8634c6dd4e358a8c7972f7c851f18f94d3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 16 Aug 2024 22:19:21 +0900 Subject: [PATCH 114/356] add merge LoRA script --- README.md | 24 +++ library/train_util.py | 2 +- networks/flux_merge_lora.py | 361 ++++++++++++++++++++++++++++++++++++ 3 files changed, 386 insertions(+), 1 deletion(-) create mode 100644 networks/flux_merge_lora.py diff --git a/README.md b/README.md index 6fb050df..e231cc24 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,8 @@ The command to install PyTorch is as follows: Aug 16, 2024: +Added a script `networks/flux_merge_lora.py` to merge LoRA into FLUX.1 checkpoint. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. + FLUX.1 schnell model based training is now supported (but not tested). If the name of the model file contains `schnell`, the model is treated as a schnell model. Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. The default is 512 in dev and 256 in schnell. @@ -80,6 +82,28 @@ Aug 12: `--interactive` option is now working. python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` +### Merge LoRA to FLUX.1 checkpoint + +`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__ + +``` +python networks/flux_merge_lora.py --flux_model flux1-dev.sft --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu +``` + +You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`. + +`--loading_device` is the device to load the LoRA models. `--working_device` is the device to merge (calculate) the models. Default is `cpu` for both. Loading / working device examples are below (in the case of `--save_precision fp16` or `--save_precision bf16`): + +- 'cpu' / 'cpu': Uses >50GB of RAM, but works on any machine. +- 'cuda' / 'cpu': Uses 24GB of VRAM, but requires 30GB of RAM. +- 'cuda' / 'cuda': Uses 30GB of VRAM, but requires 30GB of RAM, faster than 'cuda' / 'cpu'. + +In the case of LoRA models are trained with `bf16`, we are not sure which is better, `fp16` or `bf16` for `--save_precision`. + +The script can merge multiple LoRA models. If you want to merge multiple LoRA models, specify `--concat` option to work the merged LoRA model properly. + +``` + ## SD3 training SD3 training is done with `sd3_train.py`. diff --git a/library/train_util.py b/library/train_util.py index 59ec3e56..fa0eb9e5 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3160,7 +3160,7 @@ SS_METADATA_MINIMUM_KEYS = [ def build_minimum_network_metadata( - v2: Optional[bool], + v2: Optional[str], base_model: Optional[str], network_module: str, network_dim: str, diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py new file mode 100644 index 00000000..c3986ef1 --- /dev/null +++ b/networks/flux_merge_lora.py @@ -0,0 +1,361 @@ +import math +import argparse +import os +import time +import torch +from safetensors import safe_open +from safetensors.torch import load_file, save_file +from tqdm import tqdm +from library import sai_model_spec, train_util +import networks.lora_flux as lora_flux +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def load_state_dict(file_name, dtype): + if os.path.splitext(file_name)[1] == ".safetensors": + sd = load_file(file_name) + metadata = train_util.load_metadata_from_safetensors(file_name) + else: + sd = torch.load(file_name, map_location="cpu") + metadata = {} + + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + + return sd, metadata + + +def save_to_file(file_name, state_dict, dtype, metadata): + if dtype is not None: + logger.info(f"converting to {dtype}...") + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + logger.info(f"saving to: {file_name}") + save_file(state_dict, file_name, metadata=metadata) + + +def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): + # create module map without loading state_dict + logger.info(f"loading keys from FLUX.1 model: {flux_model}") + lora_name_to_module_key = {} + with safe_open(flux_model, framework="pt", device=loading_device) as flux_file: + keys = list(flux_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") + lora_name_to_module_key[lora_name] = key + + flux_state_dict = load_file(flux_model, device=loading_device) + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU + + logger.info(f"merging...") + for key in tqdm(lora_sd.keys()): + if "lora_down" in key: + lora_name = key[: key.rfind(".lora_down")] + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + if lora_name not in lora_name_to_module_key: + logger.warning(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.") + continue + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + module_weight_key = lora_name_to_module_key[lora_name] + if module_weight_key not in flux_state_dict: + weight = flux_file.get_tensor(module_weight_key) + else: + weight = flux_state_dict[module_weight_key] + + weight = weight.to(working_device, merge_dtype) + up_weight = up_weight.to(working_device, merge_dtype) + down_weight = down_weight.to(working_device, merge_dtype) + + # logger.info(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype) + del up_weight + del down_weight + del weight + + return flux_state_dict + + +def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): + base_alphas = {} # alpha for merged model + base_dims = {} + + merged_sd = {} + base_model = None + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, lora_metadata = load_state_dict(model, merge_dtype) + + if lora_metadata is not None: + if base_model is None: + base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + + # get alpha and dim + alphas = {} # alpha for current model + dims = {} # dims for current model + for key in lora_sd.keys(): + if "alpha" in key: + lora_module_name = key[: key.rfind(".alpha")] + alpha = float(lora_sd[key].detach().numpy()) + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + elif "lora_down" in key: + lora_module_name = key[: key.rfind(".lora_down")] + dim = lora_sd[key].size()[0] + dims[lora_module_name] = dim + if lora_module_name not in base_dims: + base_dims[lora_module_name] = dim + + for lora_module_name in dims.keys(): + if lora_module_name not in alphas: + alpha = dims[lora_module_name] + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + + # merge + logger.info(f"merging...") + for key in tqdm(lora_sd.keys()): + if "alpha" in key: + continue + + if "lora_up" in key and concat: + concat_dim = 1 + elif "lora_down" in key and concat: + concat_dim = 0 + else: + concat_dim = None + + lora_module_name = key[: key.rfind(".lora_")] + + base_alpha = base_alphas[lora_module_name] + alpha = alphas[lora_module_name] + + scale = math.sqrt(alpha / base_alpha) * ratio + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + + if key in merged_sd: + assert ( + merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None + ), f"weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。" + if concat_dim is not None: + merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) + else: + merged_sd[key] = merged_sd[key] + lora_sd[key] * scale + else: + merged_sd[key] = lora_sd[key] * scale + + # set alpha to sd + for lora_module_name, alpha in base_alphas.items(): + key = lora_module_name + ".alpha" + merged_sd[key] = torch.tensor(alpha) + if shuffle: + key_down = lora_module_name + ".lora_down.weight" + key_up = lora_module_name + ".lora_up.weight" + dim = merged_sd[key_down].shape[0] + perm = torch.randperm(dim) + merged_sd[key_down] = merged_sd[key_down][perm] + merged_sd[key_up] = merged_sd[key_up][:, perm] + + logger.info("merged model") + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + + # check all dims are same + dims_list = list(set(base_dims.values())) + alphas_list = list(set(base_alphas.values())) + all_same_dims = True + all_same_alphas = True + for dims in dims_list: + if dims != dims_list[0]: + all_same_dims = False + break + for alphas in alphas_list: + if alphas != alphas_list[0]: + all_same_alphas = False + break + + # build minimum metadata + dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" + alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" + metadata = train_util.build_minimum_network_metadata(str(False), base_model, "networks.lora", dims, alphas, None) + + return merged_sd, metadata + + +def merge(args): + assert len(args.models) == len( + args.ratios + ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + dest_dir = os.path.dirname(args.save_to) + if not os.path.exists(dest_dir): + logger.info(f"creating directory: {dest_dir}") + os.makedirs(dest_dir) + + if args.flux_model is not None: + state_dict = merge_to_flux_model( + args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + ) + + if args.no_metadata: + sai_metadata = None + else: + merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev" + ) + + logger.info(f"saving FLUX model to: {args.save_to}") + save_to_file(args.save_to, state_dict, save_dtype, sai_metadata) + + else: + state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + + logger.info(f"calculating hashes and creating metadata...") + + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + if not args.no_metadata: + merged_from = sai_model_spec.build_merged_from(args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" + ) + metadata.update(sai_metadata) + + logger.info(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, save_dtype, metadata) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", + ) + parser.add_argument( + "--precision", + type=str, + default="float", + choices=["float", "fp16", "bf16"], + help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", + ) + parser.add_argument( + "--flux_model", + type=str, + default=None, + help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする", + ) + parser.add_argument( + "--loading_device", + type=str, + default="cpu", + help="device to load FLUX.1 model. LoRA models are loaded on CPU / FLUX.1モデルを読み込むデバイス。LoRAモデルはCPUで読み込まれます", + ) + parser.add_argument( + "--working_device", + type=str, + default="cpu", + help="device to work (merge). Merging LoRA models are done on CPU." + + " / 作業(マージ)するデバイス。LoRAモデルのマージはCPUで行われます。", + ) + parser.add_argument( + "--save_to", + type=str, + default=None, + help="destination file name: safetensors file / 保存先のファイル名、safetensorsファイル", + ) + parser.add_argument( + "--models", + type=str, + nargs="*", + help="LoRA models to merge: safetensors file / マージするLoRAモデル、safetensorsファイル", + ) + parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument( + "--no_metadata", + action="store_true", + help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", + ) + parser.add_argument( + "--concat", + action="store_true", + help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / " + + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)", + ) + parser.add_argument( + "--shuffle", + action="store_true", + help="shuffle lora weight./ " + "LoRAの重みをシャッフルする", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + merge(args) From 7367584e6749448cb9b012df0d3bcbe4f0531ea5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 17 Aug 2024 14:38:34 +0900 Subject: [PATCH 115/356] fix sd3 training to work without cachine TE outputs #1465 --- sd3_train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index 9c37cbce..3b6c8a11 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -759,8 +759,9 @@ def train(args): input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # TODO support weighted captions - input_ids_clip_l = input_ids_clip_l.to(accelerator.device) - input_ids_clip_g = input_ids_clip_g.to(accelerator.device) + # text models in sd3_models require "cpu" for input_ids + input_ids_clip_l = input_ids_clip_l.to("cpu") + input_ids_clip_g = input_ids_clip_g.to("cpu") lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [clip_l, clip_g, None], @@ -770,7 +771,7 @@ def train(args): if t5_out is None: _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] with torch.no_grad(): - input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None + input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None _, t5_out, _ = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] ) From 400955d3ea4088e8da7a3917dec9b0664424e24a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 17 Aug 2024 15:36:18 +0900 Subject: [PATCH 116/356] add fine tuning FLUX.1 (WIP) --- flux_train.py | 729 ++++++++++++++++++++++++++++++++++++ flux_train_network.py | 168 +-------- library/flux_train_utils.py | 270 ++++++++++++- library/train_util.py | 2 +- 4 files changed, 1007 insertions(+), 162 deletions(-) create mode 100644 flux_train.py diff --git a/flux_train.py b/flux_train.py new file mode 100644 index 00000000..2ca20ded --- /dev/null +++ b/flux_train.py @@ -0,0 +1,729 @@ +# training with captions + +import argparse +import copy +import math +import os +from multiprocessing import Value +from typing import List +import toml + +from tqdm import tqdm + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux +from library.sd3_train_utils import load_prompts, FlowMatchEulerDiscreteScheduler + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + + if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False + ) + ) + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" + + # load VAE for caching latents + ae = None + if cache_latents: + ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") + ae.to(accelerator.device, dtype=weight_dtype) + ae.requires_grad_(False) + ae.eval() + + train_dataset_group.new_cache_latents(ae, accelerator.is_main_process) + + ae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # prepare tokenize strategy + if args.t5xxl_max_token_length is None: + if name == "schnell": + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy) + + # load clip_l, t5xxl for caching text encoder outputs + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu") + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu") + clip_l.eval() + t5xxl.eval() + clip_l.requires_grad_(False) + t5xxl.requires_grad_(False) + + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad here + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator.is_main_process) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + clip_l = None + t5xxl = None + + # load FLUX + # if we load to cpu, flux.to(fp8) takes a long time + flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") + + if args.gradient_checkpointing: + flux.enable_gradient_checkpointing() + + flux.requires_grad_(True) + + if not cache_latents: + # load VAE here if not cached + ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") + ae.requires_grad_(False) + ae.eval() + ae.to(accelerator.device, dtype=weight_dtype) + + training_models = [] + params_to_optimize = [] + training_models.append(flux) + params_to_optimize.append({"params": list(flux.parameters()), "lr": args.learning_rate}) + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.fused_optimizer_groups: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # This balances memory usage and management complexity. + + # calculate total number of parameters + n_total_params = sum(len(params["params"]) for params in params_to_optimize) + params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) + + # split params into groups, keeping the learning rate the same for all params in a group + # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) + grouped_params = [] + param_group = [] + param_group_lr = -1 + for group in params_to_optimize: + lr = group["lr"] + for p in group["params"]: + # if the learning rate is different for different params, start a new group + if lr != param_group_lr: + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = lr + + param_group.append(p) + + # if the group has enough parameters, start a new group + if len(param_group) == params_per_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = -1 + + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.fused_optimizer_groups: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + flux.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) # TODO check works with fp16 or not + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + flux.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) + + # if we don't cache text encoder outputs, move them to device + if not args.cache_text_encoder_outputs: + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=flux) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # acceleratorがなんかよろしくやってくれるらしい + flux = accelerator.prepare(flux) + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + elif args.fused_optimizer_groups: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def optimizer_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(optimizer_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + # For --sample_at_first + flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + + loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.fused_optimizer_groups: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"]) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance + guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device) + + # call model + l_pooled, t5_out, txt_ids = text_encoder_conds + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = flux( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + loss = train_util.conditional_loss( + model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None + ) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.fused_optimizer_groups): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(flux), + ) + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if args.logging_dir is not None: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(flux), + ) + + flux_train_utils.sample_images( + accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + ) + + is_main_process = accelerator.is_main_process + # if is_main_process: + flux = accelerator.unwrap_model(flux) + clip_l = accelerator.unwrap_model(clip_l) + clip_g = accelerator.unwrap_model(clip_g) + if t5xxl is not None: + t5xxl = accelerator.unwrap_model(t5xxl) + + accelerator.end_training() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux, ae) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) # TODO split this + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) # TODO remove this from here + flux_train_utils.add_flux_train_arguments(parser) + + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="skip latents validity check / latentsの正当性チェックをスキップする", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/flux_train_network.py b/flux_train_network.py index b9a29c16..002252c8 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -274,85 +274,14 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): weight_dtype, train_unet, ): - # copy from sd3_train.py and modified - - def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): - sigmas = self.noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) - schedule_timesteps = self.noise_scheduler_copy.timesteps.to(accelerator.device) - timesteps = timesteps.to(accelerator.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) - return sigma - - def compute_density_for_timestep_sampling( - weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None - ): - """Compute the density for sampling the timesteps when doing SD3 training. - - Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. - - SD3 paper reference: https://arxiv.org/abs/2403.03206v1. - """ - if weighting_scheme == "logit_normal": - # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") - u = torch.nn.functional.sigmoid(u) - elif weighting_scheme == "mode": - u = torch.rand(size=(batch_size,), device="cpu") - u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) - else: - u = torch.rand(size=(batch_size,), device="cpu") - return u - - def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): - """Computes loss weighting scheme for SD3 training. - - Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. - - SD3 paper reference: https://arxiv.org/abs/2403.03206v1. - """ - if weighting_scheme == "sigma_sqrt": - weighting = (sigmas**-2.0).float() - elif weighting_scheme == "cosmap": - bot = 1 - 2 * sigmas + 2 * sigmas**2 - weighting = 2 / (math.pi * bot) - else: - weighting = torch.ones_like(sigmas) - return weighting - # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] - if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": - # Simple random t-based noise sampling - if args.timestep_sampling == "sigmoid": - # https://github.com/XLabs-AI/x-flux/tree/main - t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=accelerator.device)) - else: - t = torch.rand((bsz,), device=accelerator.device) - timesteps = t * 1000.0 - t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * latents + t * noise - else: - # Sample a random timestep for each image - # for weighting schemes where we sample timesteps non-uniformly - u = compute_density_for_timestep_sampling( - weighting_scheme=args.weighting_scheme, - batch_size=bsz, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - mode_scale=args.mode_scale, - ) - indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) - - # Add noise according to flow matching. - sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) # pack latents and get img_ids packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 @@ -425,20 +354,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) - if args.model_prediction_type == "raw": - # use model_pred as is - weighting = None - elif args.model_prediction_type == "additive": - # add the model_pred to the noisy_model_input - model_pred = model_pred + noisy_model_input - weighting = None - elif args.model_prediction_type == "sigma_scaled": - # apply sigma scaling - model_pred = model_pred * (-sigmas) + noisy_model_input - - # these weighting schemes use a uniform timestep sampling - # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) # flow matching loss: this is different from SD3 target = noise - latents @@ -469,83 +386,14 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() - # sdxl_train_util.add_sdxl_training_arguments(parser) - parser.add_argument("--clip_l", type=str, help="path to clip_l") - parser.add_argument("--t5xxl", type=str, help="path to t5xxl") - parser.add_argument("--ae", type=str, help="path to ae") - parser.add_argument("--apply_t5_attn_mask", action="store_true") - parser.add_argument( - "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" - ) - parser.add_argument( - "--cache_text_encoder_outputs_to_disk", - action="store_true", - help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", - ) + flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument( "--split_mode", action="store_true", help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", ) - parser.add_argument( - "--t5xxl_max_token_length", - type=int, - default=None, - help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev" - " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512", - ) - # copy from Diffusers - parser.add_argument( - "--weighting_scheme", - type=str, - default="none", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], - ) - parser.add_argument( - "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." - ) - parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") - parser.add_argument( - "--mode_scale", - type=float, - default=1.29, - help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", - ) - parser.add_argument( - "--guidance_scale", - type=float, - default=3.5, - help="the FLUX.1 dev variant is a guidance distilled model", - ) - - parser.add_argument( - "--timestep_sampling", - choices=["sigma", "uniform", "sigmoid"], - default="sigma", - help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。", - ) - parser.add_argument( - "--sigmoid_scale", - type=float, - default=1.0, - help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', - ) - parser.add_argument( - "--model_prediction_type", - choices=["raw", "additive", "sigma_scaled"], - default="sigma_scaled", - help="How to interpret and process the model prediction: " - "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." - " / モデル予測の解釈と処理方法:" - "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", - ) - parser.add_argument( - "--discrete_flow_shift", - type=float, - default=3.0, - help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", - ) return parser diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 91f52238..167d61c7 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -12,8 +12,9 @@ from accelerate import Accelerator, PartialState from transformers import CLIPTextModel from tqdm import tqdm from PIL import Image +from safetensors.torch import save_file -from library import flux_models, flux_utils, strategy_base +from library import flux_models, flux_utils, strategy_base, train_util from library.sd3_train_utils import load_prompts from library.device_utils import init_ipex, clean_memory_on_device @@ -27,6 +28,9 @@ import logging logger = logging.getLogger(__name__) +# region sample images + + def sample_images( accelerator: Accelerator, args: argparse.Namespace, @@ -295,3 +299,267 @@ def denoise( img = img + (t_prev - t_curr) * pred return img + + +# endregion + + +# region train +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +def get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz = latents.shape[0] + sigmas = None + + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": + # Simple random t-based noise sampling + if args.timestep_sampling == "sigmoid": + # https://github.com/XLabs-AI/x-flux/tree/main + t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + else: + t = torch.rand((bsz,), device=device) + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * latents + t * noise + else: + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps[indices].to(device=device) + + # Add noise according to flow matching. + sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + return noisy_model_input, timesteps, sigmas + + +def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): + weighting = None + if args.model_prediction_type == "raw": + pass + elif args.model_prediction_type == "additive": + # add the model_pred to the noisy_model_input + model_pred = model_pred + noisy_model_input + elif args.model_prediction_type == "sigma_scaled": + # apply sigma scaling + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + return model_pred, weighting + + +def save_models(ckpt_path: str, flux: flux_models.Flux, sai_metadata: Optional[dict], save_dtype: Optional[torch.dtype] = None): + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + update_sd("", flux.state_dict()) + + save_file(state_dict, ckpt_path, metadata=sai_metadata) + + +def save_flux_model_on_train_end( + args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") + save_models(ckpt_file, flux, sai_metadata, save_dtype) + + train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_flux_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + flux: flux_models.Flux, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") + save_models(ckpt_file, flux, sai_metadata, save_dtype) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + True, + True, + epoch, + num_train_epochs, + global_step, + sd_saver, + None, + ) + + +# endregion + + +def add_flux_train_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--clip_l", + type=str, + help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument( + "--t5xxl", + type=str, + help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=None, + help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev" + " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512", + ) + parser.add_argument( + "--apply_t5_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", + ) + parser.add_argument( + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) + parser.add_argument( + "--text_encoder_batch_size", + type=int, + default=None, + help="text encoder batch size (default: None, use dataset's batch size)" + + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", + ) + parser.add_argument( + "--disable_mmap_load_safetensors", + action="store_true", + help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", + ) + + # copy from Diffusers + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--timestep_sampling", + choices=["sigma", "uniform", "sigmoid"], + default="sigma", + help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。", + ) + parser.add_argument( + "--sigmoid_scale", + type=float, + default=1.0, + help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', + ) + parser.add_argument( + "--model_prediction_type", + choices=["raw", "additive", "sigma_scaled"], + default="sigma_scaled", + help="How to interpret and process the model prediction: " + "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." + " / モデル予測の解釈と処理方法:" + "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=3.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", + ) diff --git a/library/train_util.py b/library/train_util.py index fa0eb9e5..f4ac8740 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2629,7 +2629,7 @@ class MinimalDataset(BaseDataset): raise NotImplementedError -def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: +def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset: module = ".".join(args.dataset_class.split(".")[:-1]) dataset_class = args.dataset_class.split(".")[-1] module = importlib.import_module(module) From 25f77f6ef04ee760506338e7e7f9835c28657c59 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sat, 17 Aug 2024 15:54:32 +0900 Subject: [PATCH 117/356] fix flux fine tuning to work --- README.md | 4 ++++ flux_train.py | 6 ++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e231cc24..2b7b110f 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,10 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` + +Aug 17. 2024: +Added a script `flux_train.py` to train FLUX.1. The script is experimental and not an optimized version. It needs >28GB VRAM for training. + Aug 16, 2024: Added a script `networks/flux_merge_lora.py` to merge LoRA into FLUX.1 checkpoint. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. diff --git a/flux_train.py b/flux_train.py index 2ca20ded..d2a9b3f3 100644 --- a/flux_train.py +++ b/flux_train.py @@ -674,9 +674,7 @@ def train(args): # if is_main_process: flux = accelerator.unwrap_model(flux) clip_l = accelerator.unwrap_model(clip_l) - clip_g = accelerator.unwrap_model(clip_g) - if t5xxl is not None: - t5xxl = accelerator.unwrap_model(t5xxl) + t5xxl = accelerator.unwrap_model(t5xxl) accelerator.end_training() @@ -686,7 +684,7 @@ def train(args): del accelerator # この後メモリを使うのでこれは消す if is_main_process: - flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux, ae) + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux) logger.info("model saved.") From 7e688913aef4c852f54a703c9f91d135b17dff87 Mon Sep 17 00:00:00 2001 From: exveria1015 Date: Sun, 18 Aug 2024 12:38:05 +0900 Subject: [PATCH 118/356] =?UTF-8?q?fix:=20Flux=20=E3=81=AE=20LoRA=20?= =?UTF-8?q?=E3=83=9E=E3=83=BC=E3=82=B8=E6=A9=9F=E8=83=BD=E3=82=92=E4=BF=AE?= =?UTF-8?q?=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- networks/flux_merge_lora.py | 362 +++++++++++++++++++++++++++++------- 1 file changed, 296 insertions(+), 66 deletions(-) diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index c3986ef1..df0ba606 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -1,13 +1,14 @@ -import math import argparse +import math import os import time + import torch -from safetensors import safe_open from safetensors.torch import load_file, save_file from tqdm import tqdm + +import lora_flux as lora_flux from library import sai_model_spec, train_util -import networks.lora_flux as lora_flux from library.utils import setup_logging setup_logging() @@ -42,34 +43,181 @@ def save_to_file(file_name, state_dict, dtype, metadata): save_file(state_dict, file_name, metadata=metadata) -def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): - # create module map without loading state_dict +def merge_to_flux_model( + loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype +): logger.info(f"loading keys from FLUX.1 model: {flux_model}") - lora_name_to_module_key = {} - with safe_open(flux_model, framework="pt", device=loading_device) as flux_file: - keys = list(flux_file.keys()) - for key in keys: - if key.endswith(".weight"): - module_name = ".".join(key.split(".")[:-1]) - lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") - lora_name_to_module_key[lora_name] = key - flux_state_dict = load_file(flux_model, device=loading_device) + + def create_key_map(n_double_layers, n_single_layers, hidden_size): + key_map = {} + for index in range(n_double_layers): + prefix_from = f"transformer_blocks.{index}" + prefix_to = f"double_blocks.{index}" + + for end in ("weight", "bias"): + k = f"{prefix_from}.attn." + qkv_img = f"{prefix_to}.img_attn.qkv.{end}" + qkv_txt = f"{prefix_to}.txt_attn.qkv.{end}" + + key_map[f"{k}to_q.{end}"] = (qkv_img, (0, 0, hidden_size)) + key_map[f"{k}to_k.{end}"] = (qkv_img, (0, hidden_size, hidden_size)) + key_map[f"{k}to_v.{end}"] = (qkv_img, (0, hidden_size * 2, hidden_size)) + key_map[f"{k}add_q_proj.{end}"] = (qkv_txt, (0, 0, hidden_size)) + key_map[f"{k}add_k_proj.{end}"] = ( + qkv_txt, + (0, hidden_size, hidden_size), + ) + key_map[f"{k}add_v_proj.{end}"] = ( + qkv_txt, + (0, hidden_size * 2, hidden_size), + ) + + block_map = { + "attn.to_out.0.weight": "img_attn.proj.weight", + "attn.to_out.0.bias": "img_attn.proj.bias", + "norm1.linear.weight": "img_mod.lin.weight", + "norm1.linear.bias": "img_mod.lin.bias", + "norm1_context.linear.weight": "txt_mod.lin.weight", + "norm1_context.linear.bias": "txt_mod.lin.bias", + "attn.to_add_out.weight": "txt_attn.proj.weight", + "attn.to_add_out.bias": "txt_attn.proj.bias", + "ff.net.0.proj.weight": "img_mlp.0.weight", + "ff.net.0.proj.bias": "img_mlp.0.bias", + "ff.net.2.weight": "img_mlp.2.weight", + "ff.net.2.bias": "img_mlp.2.bias", + "ff_context.net.0.proj.weight": "txt_mlp.0.weight", + "ff_context.net.0.proj.bias": "txt_mlp.0.bias", + "ff_context.net.2.weight": "txt_mlp.2.weight", + "ff_context.net.2.bias": "txt_mlp.2.bias", + "attn.norm_q.weight": "img_attn.norm.query_norm.scale", + "attn.norm_k.weight": "img_attn.norm.key_norm.scale", + "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", + "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale", + } + + for k, v in block_map.items(): + key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}" + + for index in range(n_single_layers): + prefix_from = f"single_transformer_blocks.{index}" + prefix_to = f"single_blocks.{index}" + + for end in ("weight", "bias"): + k = f"{prefix_from}.attn." + qkv = f"{prefix_to}.linear1.{end}" + key_map[f"{k}to_q.{end}"] = (qkv, (0, 0, hidden_size)) + key_map[f"{k}to_k.{end}"] = (qkv, (0, hidden_size, hidden_size)) + key_map[f"{k}to_v.{end}"] = (qkv, (0, hidden_size * 2, hidden_size)) + key_map[f"{prefix_from}.proj_mlp.{end}"] = ( + qkv, + (0, hidden_size * 3, hidden_size * 4), + ) + + block_map = { + "norm.linear.weight": "modulation.lin.weight", + "norm.linear.bias": "modulation.lin.bias", + "proj_out.weight": "linear2.weight", + "proj_out.bias": "linear2.bias", + "attn.norm_q.weight": "norm.query_norm.scale", + "attn.norm_k.weight": "norm.key_norm.scale", + } + + for k, v in block_map.items(): + key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}" + + return key_map + + key_map = create_key_map( + 18, 1, 2048 + ) # Assuming 18 double layers, 1 single layer, and hidden size of 2048 + + def find_matching_key(flux_dict, lora_key): + lora_key = lora_key.replace("diffusion_model.", "") + lora_key = lora_key.replace("transformer.", "") + lora_key = lora_key.replace("lora_A", "lora_down").replace("lora_B", "lora_up") + lora_key = lora_key.replace("single_transformer_blocks", "single_blocks") + lora_key = lora_key.replace("transformer_blocks", "double_blocks") + + double_block_map = { + "attn.to_out.0": "img_attn.proj", + "norm1.linear": "img_mod.lin", + "norm1_context.linear": "txt_mod.lin", + "attn.to_add_out": "txt_attn.proj", + "ff.net.0.proj": "img_mlp.0", + "ff.net.2": "img_mlp.2", + "ff_context.net.0.proj": "txt_mlp.0", + "ff_context.net.2": "txt_mlp.2", + "attn.norm_q": "img_attn.norm.query_norm", + "attn.norm_k": "img_attn.norm.key_norm", + "attn.norm_added_q": "txt_attn.norm.query_norm", + "attn.norm_added_k": "txt_attn.norm.key_norm", + "attn.to_q": "img_attn.qkv", + "attn.to_k": "img_attn.qkv", + "attn.to_v": "img_attn.qkv", + "attn.add_q_proj": "txt_attn.qkv", + "attn.add_k_proj": "txt_attn.qkv", + "attn.add_v_proj": "txt_attn.qkv", + } + + single_block_map = { + "norm.linear": "modulation.lin", + "proj_out": "linear2", + "attn.norm_q": "norm.query_norm", + "attn.norm_k": "norm.key_norm", + "attn.to_q": "linear1", + "attn.to_k": "linear1", + "attn.to_v": "linear1", + } + + for old, new in double_block_map.items(): + lora_key = lora_key.replace(old, new) + + for old, new in single_block_map.items(): + lora_key = lora_key.replace(old, new) + + if lora_key in key_map: + flux_key = key_map[lora_key] + if isinstance(flux_key, tuple): + flux_key = flux_key[0] + logger.info(f"Found matching key: {flux_key}") + return flux_key + + # If not found in key_map, try partial matching + potential_key = lora_key + ".weight" + logger.info(f"Searching for key: {potential_key}") + matches = [k for k in flux_dict.keys() if potential_key in k] + if matches: + logger.info(f"Found matching key: {matches[0]}") + return matches[0] + return None + + merged_keys = set() for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") - lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU + lora_sd, _ = load_state_dict(model, merge_dtype) - logger.info(f"merging...") + logger.info("merging...") for key in tqdm(lora_sd.keys()): - if "lora_down" in key: - lora_name = key[: key.rfind(".lora_down")] - up_key = key.replace("lora_down", "lora_up") - alpha_key = key[: key.index("lora_down")] + "alpha" + if "lora_down" in key or "lora_A" in key: + lora_name = key[ + : key.rfind(".lora_down" if "lora_down" in key else ".lora_A") + ] + up_key = key.replace("lora_down", "lora_up").replace("lora_A", "lora_B") + alpha_key = ( + key[: key.index("lora_down" if "lora_down" in key else "lora_A")] + + "alpha" + ) - if lora_name not in lora_name_to_module_key: - logger.warning(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.") + logger.info(f"Processing LoRA key: {lora_name}") + flux_key = find_matching_key(flux_state_dict, lora_name) + + if flux_key is None: + logger.warning(f"no module found for LoRA weight: {key}") continue + logger.info(f"Merging LoRA key {lora_name} into Flux key {flux_key}") + down_weight = lora_sd[key] up_weight = lora_sd[up_key] @@ -77,40 +225,74 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati alpha = lora_sd.get(alpha_key, dim) scale = alpha / dim - # W <- W + U * D - module_weight_key = lora_name_to_module_key[lora_name] - if module_weight_key not in flux_state_dict: - weight = flux_file.get_tensor(module_weight_key) - else: - weight = flux_state_dict[module_weight_key] + weight = flux_state_dict[flux_key] weight = weight.to(working_device, merge_dtype) up_weight = up_weight.to(working_device, merge_dtype) down_weight = down_weight.to(working_device, merge_dtype) - # logger.info(module_name, down_weight.size(), up_weight.size()) - if len(weight.size()) == 2: - # linear - weight = weight + ratio * (up_weight @ down_weight) * scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - weight - + ratio - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * scale - ) - else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + ratio * conved * scale + if lora_name.startswith("transformer."): + if "qkv" in flux_key: + hidden_size = weight.size(-1) // 3 + update = ratio * (up_weight @ down_weight) * scale - flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype) + if "img_attn" in flux_key or "txt_attn" in flux_key: + q, k, v = torch.chunk(weight, 3, dim=-1) + if "to_q" in lora_name or "add_q_proj" in lora_name: + q += update.reshape(q.shape) + elif "to_k" in lora_name or "add_k_proj" in lora_name: + k += update.reshape(k.shape) + elif "to_v" in lora_name or "add_v_proj" in lora_name: + v += update.reshape(v.shape) + weight = torch.cat([q, k, v], dim=-1) + else: + if len(weight.size()) == 2: + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + weight = ( + weight + + ratio + * ( + up_weight.squeeze(3).squeeze(2) + @ down_weight.squeeze(3).squeeze(2) + ) + .unsqueeze(2) + .unsqueeze(3) + * scale + ) + else: + conved = torch.nn.functional.conv2d( + down_weight.permute(1, 0, 2, 3), up_weight + ).permute(1, 0, 2, 3) + weight = weight + ratio * conved * scale + else: + if len(weight.size()) == 2: + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + weight = ( + weight + + ratio + * ( + up_weight.squeeze(3).squeeze(2) + @ down_weight.squeeze(3).squeeze(2) + ) + .unsqueeze(2) + .unsqueeze(3) + * scale + ) + else: + conved = torch.nn.functional.conv2d( + down_weight.permute(1, 0, 2, 3), up_weight + ).permute(1, 0, 2, 3) + weight = weight + ratio * conved * scale + + flux_state_dict[flux_key] = weight.to(loading_device, save_dtype) + merged_keys.add(flux_key) del up_weight del down_weight del weight + logger.info(f"Merged keys: {sorted(list(merged_keys))}") return flux_state_dict @@ -126,7 +308,9 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_metadata is not None: if base_model is None: - base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + base_model = lora_metadata.get( + train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None + ) # get alpha and dim alphas = {} # alpha for current model @@ -152,10 +336,12 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_module_name not in base_alphas: base_alphas[lora_module_name] = alpha - logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + logger.info( + f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}" + ) # merge - logger.info(f"merging...") + logger.info("merging...") for key in tqdm(lora_sd.keys()): if "alpha" in key: continue @@ -173,14 +359,19 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): alpha = alphas[lora_module_name] scale = math.sqrt(alpha / base_alpha) * ratio - scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + scale = ( + abs(scale) if "lora_up" in key else scale + ) # マイナスの重みに対応する。 if key in merged_sd: assert ( - merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None - ), f"weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。" + merged_sd[key].size() == lora_sd[key].size() + or concat_dim is not None + ), "weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。" if concat_dim is not None: - merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) + merged_sd[key] = torch.cat( + [merged_sd[key], lora_sd[key] * scale], dim=concat_dim + ) else: merged_sd[key] = merged_sd[key] + lora_sd[key] * scale else: @@ -199,7 +390,9 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): merged_sd[key_up] = merged_sd[key_up][:, perm] logger.info("merged model") - logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + logger.info( + f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}" + ) # check all dims are same dims_list = list(set(base_dims.values())) @@ -218,15 +411,17 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): # build minimum metadata dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" - metadata = train_util.build_minimum_network_metadata(str(False), base_model, "networks.lora", dims, alphas, None) + metadata = train_util.build_minimum_network_metadata( + str(False), base_model, "networks.lora", dims, alphas, None + ) return merged_sd, metadata def merge(args): - assert len(args.models) == len( - args.ratios - ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + assert ( + len(args.models) == len(args.ratios) + ), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" def str_to_dtype(p): if p == "float": @@ -249,27 +444,48 @@ def merge(args): if args.flux_model is not None: state_dict = merge_to_flux_model( - args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + args.loading_device, + args.working_device, + args.flux_model, + args.models, + args.ratios, + merge_dtype, + save_dtype, ) if args.no_metadata: sai_metadata = None else: - merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) + merged_from = sai_model_spec.build_merged_from( + [args.flux_model] + args.models + ) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev" + None, + False, + False, + False, + False, + False, + time.time(), + title=title, + merged_from=merged_from, + flux="dev", ) logger.info(f"saving FLUX model to: {args.save_to}") save_to_file(args.save_to, state_dict, save_dtype, sai_metadata) else: - state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + state_dict, metadata = merge_lora_models( + args.models, args.ratios, merge_dtype, args.concat, args.shuffle + ) - logger.info(f"calculating hashes and creating metadata...") + logger.info("calculating hashes and creating metadata...") - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes( + state_dict, metadata + ) metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash @@ -277,7 +493,16 @@ def merge(args): merged_from = sai_model_spec.build_merged_from(args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" + state_dict, + False, + False, + False, + True, + False, + time.time(), + title=title, + merged_from=merged_from, + flux="dev", ) metadata.update(sai_metadata) @@ -332,7 +557,12 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="LoRA models to merge: safetensors file / マージするLoRAモデル、safetensorsファイル", ) - parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument( + "--ratios", + type=float, + nargs="*", + help="ratios for each model / それぞれのLoRAモデルの比率", + ) parser.add_argument( "--no_metadata", action="store_true", From ef535ec6bb99918027afc1e31efa72cd3761d453 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 18 Aug 2024 16:54:18 +0900 Subject: [PATCH 119/356] add memory efficient training for FLUX.1 --- README.md | 64 +++++++++++++-- flux_train.py | 175 ++++++++++++++++++++++++++++----------- library/flux_models.py | 182 ++++++++++++++++++++++++++++++++++++----- 3 files changed, 348 insertions(+), 73 deletions(-) diff --git a/README.md b/README.md index 2b7b110f..521e82e8 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,11 @@ The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` -Aug 17. 2024: +Aug 18, 2024: +Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. + + +Aug 17, 2024: Added a script `flux_train.py` to train FLUX.1. The script is experimental and not an optimized version. It needs >28GB VRAM for training. Aug 16, 2024: @@ -39,11 +43,23 @@ Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-ge Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. -We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. + +### FLUX.1 LoRA training + +We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. ``` -accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 --loss_type l2 +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py +--pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors +--ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers +--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 +--network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 +--network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base +--highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml +--output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid +--model_prediction_type raw --guidance_scale 1.0 --loss_type l2 ``` +(The command is multi-line for readability. Please combine it into one line.) The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below: @@ -80,12 +96,44 @@ The trained LoRA model can be used with ComfyUI. The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. -Aug 12: `--interactive` option is now working. - ``` python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` +### FLUX.1 fine-tuning + +Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GPUs, and 64GB main memory is recommended. + +``` +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py +--pretrained_model_name_or_path flux1-dev.sft --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.sft +--mixed_precision bf16 --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 +--seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 +--dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name test-bf16 +--learning_rate 5e-5 --max_train_epochs 4 --sdpa --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1 +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" +--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 +--blockwise_fused_optimizer --double_blocks_to_swap 6 --cpu_offload_checkpointing +``` + +(Combine the command into one line.) + +Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizer`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. + +`--blockwise_fused_optimizer` enables the fusing of the optimizer for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--fused_optimizer_groups` is deprecated due to the addition of this option for FLUX.1 training. + +`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--blockwise_fused_optimizer`. + +`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. + +All these options are experimental and may change in the future. + +The increasing the number of blocks to swap may reduce the memory usage, but the training speed will be slower. `--cpu_offload_checkpointing` also slows down the training. + +Swap 6 double blocks and use cpu offload checkpointing may be a good starting point. Please try different settings according to VRAM usage and training speed. + +The learning rate and the number of epochs are not optimized yet. Please adjust them according to the training results. + ### Merge LoRA to FLUX.1 checkpoint `networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__ @@ -298,7 +346,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. - - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available. + - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only Adafactor is supported. Gradient accumulation is not available. - Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`. - Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size. - PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`. @@ -308,7 +356,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer. - Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect is limited to a certain number, it is recommended to specify 4-10. - Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available. - - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required. + - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using Adafactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required. - Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side. - LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO! @@ -361,7 +409,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。 - - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。 + - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は Adafactor のみ対応しています。また gradient accumulation は使えません。 - mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。 - バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。 - PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。 diff --git a/flux_train.py b/flux_train.py index d2a9b3f3..ecb3c7dd 100644 --- a/flux_train.py +++ b/flux_train.py @@ -1,5 +1,15 @@ # training with captions +# Swap blocks between CPU and GPU: +# This implementation is inspired by and based on the work of 2kpr. +# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading. +# The original idea has been adapted and extended to fit the current project's needs. + +# Key features: +# - CPU offloading during forward and backward passes +# - Use of fused optimizer and grad_hook for efficient gradient processing +# - Per-block fused optimizer instances + import argparse import copy import math @@ -54,6 +64,12 @@ def train(args): ) args.cache_text_encoder_outputs = True + if args.cpu_offload_checkpointing and not args.gradient_checkpointing: + logger.warning( + "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります" + ) + args.gradient_checkpointing = True + cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None @@ -232,16 +248,25 @@ def train(args): # now we can delete Text Encoders to free memory clip_l = None t5xxl = None + clean_memory_on_device(accelerator.device) # load FLUX # if we load to cpu, flux.to(fp8) takes a long time flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") if args.gradient_checkpointing: - flux.enable_gradient_checkpointing() + flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing) flux.requires_grad_(True) + if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! + logger.info( + f"enable block swap: double_blocks_to_swap={args.double_blocks_to_swap}, single_blocks_to_swap={args.single_blocks_to_swap}" + ) + flux.enable_block_swap(args.double_blocks_to_swap, args.single_blocks_to_swap) + if not cache_latents: # load VAE here if not cached ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") @@ -265,40 +290,43 @@ def train(args): # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html - # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. # This balances memory usage and management complexity. - # calculate total number of parameters - n_total_params = sum(len(params["params"]) for params in params_to_optimize) - params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) - - # split params into groups, keeping the learning rate the same for all params in a group - # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) + # split params into groups. currently different learning rates are not supported grouped_params = [] - param_group = [] - param_group_lr = -1 + param_group = {} for group in params_to_optimize: - lr = group["lr"] - for p in group["params"]: - # if the learning rate is different for different params, start a new group - if lr != param_group_lr: - if param_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) - param_group = [] - param_group_lr = lr + named_parameters = list(flux.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # double, single or other + if np[0].startswith("double_blocks"): + block_idx = int(np[0].split(".")[1]) + block_type = "double" + elif np[0].startswith("single_blocks"): + block_idx = int(np[0].split(".")[1]) + block_type = "single" + else: + block_idx = -1 - param_group.append(p) + param_group_key = (block_type, block_idx) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) - # if the group has enough parameters, start a new group - if len(param_group) == params_per_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) - param_group = [] - param_group_lr = -1 + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) - if param_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") # prepare optimizers for each group optimizers = [] @@ -307,7 +335,7 @@ def train(args): optimizers.append(optimizer) optimizer = optimizers[0] # avoid error in the following code - logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) @@ -341,7 +369,7 @@ def train(args): train_dataset_group.set_max_train_steps(args.max_train_steps) # lr schedulerを用意する - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: # prepare lr schedulers for each optimizer lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] lr_scheduler = lr_schedulers[0] # avoid error in the following code @@ -414,7 +442,7 @@ def train(args): parameter.register_post_accumulate_grad_hook(__grad_hook) - elif args.fused_optimizer_groups: + elif args.blockwise_fused_optimizers: # prepare for additional optimizers and lr schedulers for i in range(1, len(optimizers)): optimizers[i] = accelerator.prepare(optimizers[i]) @@ -429,22 +457,46 @@ def train(args): num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} + double_blocks_to_swap = args.double_blocks_to_swap + single_blocks_to_swap = args.single_blocks_to_swap + num_double_blocks = len(flux.double_blocks) + num_single_blocks = len(flux.single_blocks) + for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: + block_type, block_idx = block_types_and_indices[opt_idx] - def optimizer_hook(parameter: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + def create_optimizer_hook(btype, bidx): + def optimizer_hook(parameter: torch.Tensor): + # print(f"optimizer_hook: {btype}, {bidx}") + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) - i = parameter_optimizer_map[parameter] - optimizer_hooked_count[i] += 1 - if optimizer_hooked_count[i] == num_parameters_per_group[i]: - optimizers[i].step() - optimizers[i].zero_grad(set_to_none=True) + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) - parameter.register_post_accumulate_grad_hook(optimizer_hook) + # swap blocks if necessary + if btype == "double" and double_blocks_to_swap: + if bidx >= num_double_blocks - double_blocks_to_swap: + bidx_cuda = double_blocks_to_swap - (num_double_blocks - bidx) + flux.double_blocks[bidx].to("cpu") + flux.double_blocks[bidx_cuda].to(accelerator.device) + # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device") + elif btype == "single" and single_blocks_to_swap: + if bidx >= num_single_blocks - single_blocks_to_swap: + bidx_cuda = single_blocks_to_swap - (num_single_blocks - bidx) + flux.single_blocks[bidx].to("cpu") + flux.single_blocks[bidx_cuda].to(accelerator.device) + # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device") + + return optimizer_hook + + parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx)) parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 @@ -487,6 +539,9 @@ def train(args): init_kwargs=init_kwargs, ) + if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None: + flux.prepare_block_swap_before_forward() + # For --sample_at_first flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) @@ -502,7 +557,7 @@ def train(args): for step, batch in enumerate(train_dataloader): current_step.value = global_step - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step with accelerator.accumulate(*training_models): @@ -591,7 +646,7 @@ def train(args): # backward accelerator.backward(loss) - if not (args.fused_backward_pass or args.fused_optimizer_groups): + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): if accelerator.sync_gradients and args.max_grad_norm != 0.0: params_to_clip = [] for m in training_models: @@ -604,7 +659,7 @@ def train(args): else: # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook lr_scheduler.step() - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: for i in range(1, len(optimizers)): lr_schedulers[i].step() @@ -614,7 +669,7 @@ def train(args): global_step += 1 flux_train_utils.sample_images( - accelerator, args, epoch, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs ) # 指定ステップごとにモデルを保存 @@ -673,8 +728,6 @@ def train(args): is_main_process = accelerator.is_main_process # if is_main_process: flux = accelerator.unwrap_model(flux) - clip_l = accelerator.unwrap_model(clip_l) - t5xxl = accelerator.unwrap_model(t5xxl) accelerator.end_training() @@ -707,13 +760,43 @@ def setup_parser() -> argparse.ArgumentParser: "--fused_optimizer_groups", type=int, default=None, - help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます", + ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", ) parser.add_argument( "--skip_latents_validity_check", action="store_true", help="skip latents validity check / latentsの正当性チェックをスキップする", ) + parser.add_argument( + "--double_blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of 'double_blocks' (~640MB) to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップする'変換ブロック'(約640MB)の数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) + parser.add_argument( + "--single_blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of 'single_blocks' (~320MB) to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップする'変換ブロック'(約320MB)の数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", + ) return parser diff --git a/library/flux_models.py b/library/flux_models.py index ed0bc8c7..3f44068f 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -4,6 +4,11 @@ from dataclasses import dataclass import math +from typing import Optional + +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() import torch from einops import rearrange @@ -466,6 +471,33 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tenso # region layers + + +# for cpu_offload_checkpointing + + +def to_cuda(x): + if isinstance(x, torch.Tensor): + return x.cuda() + elif isinstance(x, (list, tuple)): + return [to_cuda(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cuda(v) for k, v in x.items()} + else: + return x + + +def to_cpu(x): + if isinstance(x, torch.Tensor): + return x.cpu() + elif isinstance(x, (list, tuple)): + return [to_cpu(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cpu(v) for k, v in x.items()} + else: + return x + + class EmbedND(nn.Module): def __init__(self, dim: int, theta: int, axes_dim: list[int]): super().__init__() @@ -648,16 +680,15 @@ class DoubleStreamBlock(nn.Module): ) self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False - def enable_gradient_checkpointing(self): + def enable_gradient_checkpointing(self, cpu_offload: bool = False): self.gradient_checkpointing = True - # self.img_attn.enable_gradient_checkpointing() - # self.txt_attn.enable_gradient_checkpointing() + self.cpu_offload_checkpointing = cpu_offload def disable_gradient_checkpointing(self): self.gradient_checkpointing = False - # self.img_attn.disable_gradient_checkpointing() - # self.txt_attn.disable_gradient_checkpointing() + self.cpu_offload_checkpointing = False def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: img_mod1, img_mod2 = self.img_mod(vec) @@ -694,11 +725,24 @@ class DoubleStreamBlock(nn.Module): txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) return img, txt - def forward(self, *args, **kwargs): + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: if self.training and self.gradient_checkpointing: - return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + if not self.cpu_offload_checkpointing: + return checkpoint(self._forward, img, txt, vec, pe, use_reentrant=False) + # cpu offload checkpointing + + def create_custom_forward(func): + def custom_forward(*inputs): + cuda_inputs = to_cuda(inputs) + outputs = func(*cuda_inputs) + return to_cpu(outputs) + + return custom_forward + + return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe) + else: - return self._forward(*args, **kwargs) + return self._forward(img, txt, vec, pe) # def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): # if self.training and self.gradient_checkpointing: @@ -747,12 +791,15 @@ class SingleStreamBlock(nn.Module): self.modulation = Modulation(hidden_size, double=False) self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False - def enable_gradient_checkpointing(self): + def enable_gradient_checkpointing(self, cpu_offload: bool = False): self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload def disable_gradient_checkpointing(self): self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: mod, _ = self.modulation(vec) @@ -768,11 +815,24 @@ class SingleStreamBlock(nn.Module): output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) return x + mod.gate * output - def forward(self, *args, **kwargs): + def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: if self.training and self.gradient_checkpointing: - return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + if not self.cpu_offload_checkpointing: + return checkpoint(self._forward, x, vec, pe, use_reentrant=False) + + # cpu offload checkpointing + + def create_custom_forward(func): + def custom_forward(*inputs): + cuda_inputs = to_cuda(inputs) + outputs = func(*cuda_inputs) + return to_cpu(outputs) + + return custom_forward + + return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe) else: - return self._forward(*args, **kwargs) + return self._forward(x, vec, pe) # def forward(self, x: Tensor, vec: Tensor, pe: Tensor): # if self.training and self.gradient_checkpointing: @@ -849,6 +909,9 @@ class Flux(nn.Module): self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + self.double_blocks_to_swap = None + self.single_blocks_to_swap = None @property def device(self): @@ -858,8 +921,9 @@ class Flux(nn.Module): def dtype(self): return next(self.parameters()).dtype - def enable_gradient_checkpointing(self): + def enable_gradient_checkpointing(self, cpu_offload: bool = False): self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload self.time_in.enable_gradient_checkpointing() self.vector_in.enable_gradient_checkpointing() @@ -867,12 +931,13 @@ class Flux(nn.Module): self.guidance_in.enable_gradient_checkpointing() for block in self.double_blocks + self.single_blocks: - block.enable_gradient_checkpointing() + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) - print("FLUX: Gradient checkpointing enabled.") + print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}") def disable_gradient_checkpointing(self): self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False self.time_in.disable_gradient_checkpointing() self.vector_in.disable_gradient_checkpointing() @@ -884,6 +949,24 @@ class Flux(nn.Module): print("FLUX: Gradient checkpointing disabled.") + def enable_block_swap(self, double_blocks: Optional[int], single_blocks: Optional[int]): + self.double_blocks_to_swap = double_blocks + self.single_blocks_to_swap = single_blocks + + def prepare_block_swap_before_forward(self): + # move last n blocks to cpu: they are on cuda + if self.double_blocks_to_swap: + for i in range(len(self.double_blocks) - self.double_blocks_to_swap): + self.double_blocks[i].to(self.device) + for i in range(len(self.double_blocks) - self.double_blocks_to_swap, len(self.double_blocks)): + self.double_blocks[i].to("cpu") # , non_blocking=True) + if self.single_blocks_to_swap: + for i in range(len(self.single_blocks) - self.single_blocks_to_swap): + self.single_blocks[i].to(self.device) + for i in range(len(self.single_blocks) - self.single_blocks_to_swap, len(self.single_blocks)): + self.single_blocks[i].to("cpu") # , non_blocking=True) + clean_memory_on_device(self.device) + def forward( self, img: Tensor, @@ -910,14 +993,75 @@ class Flux(nn.Module): ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) - for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + if not self.double_blocks_to_swap: + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + else: + # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning + for block_idx in range(self.double_blocks_to_swap): + block = self.double_blocks[len(self.double_blocks) - self.double_blocks_to_swap + block_idx] + if block.parameters().__next__().device.type != "cpu": + block.to("cpu") # , non_blocking=True) + # print(f"Moved double block {len(self.double_blocks) - self.double_blocks_to_swap + block_idx} to cpu.") + + block = self.double_blocks[block_idx] + if block.parameters().__next__().device.type == "cpu": + block.to(self.device) + # print(f"Moved double block {block_idx} to cuda.") + + to_cpu_block_index = 0 + for block_idx, block in enumerate(self.double_blocks): + # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda + moving = block_idx >= len(self.double_blocks) - self.double_blocks_to_swap + if moving: + block.to(self.device) # move to cuda + # print(f"Moved double block {block_idx} to cuda.") + + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + if moving: + self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) + # print(f"Moved double block {to_cpu_block_index} to cpu.") + to_cpu_block_index += 1 img = torch.cat((txt, img), 1) - for block in self.single_blocks: - img = block(img, vec=vec, pe=pe) + + if not self.single_blocks_to_swap: + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + else: + # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning + for block_idx in range(self.single_blocks_to_swap): + block = self.single_blocks[len(self.single_blocks) - self.single_blocks_to_swap + block_idx] + if block.parameters().__next__().device.type != "cpu": + block.to("cpu") # , non_blocking=True) + # print(f"Moved single block {len(self.single_blocks) - self.single_blocks_to_swap + block_idx} to cpu.") + + block = self.single_blocks[block_idx] + if block.parameters().__next__().device.type == "cpu": + block.to(self.device) + # print(f"Moved single block {block_idx} to cuda.") + + to_cpu_block_index = 0 + for block_idx, block in enumerate(self.single_blocks): + # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda + moving = block_idx >= len(self.single_blocks) - self.single_blocks_to_swap + if moving: + block.to(self.device) # move to cuda + # print(f"Moved single block {block_idx} to cuda.") + + img = block(img, vec=vec, pe=pe) + + if moving: + self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) + # print(f"Moved single block {to_cpu_block_index} to cpu.") + img = img[:, txt.shape[1] :, ...] + if self.training and self.cpu_offload_checkpointing: + img = img.to(self.device) + vec = vec.to(self.device) + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img From a45048892802dce43e86a7e377ba84e89b51fdf5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 18 Aug 2024 16:56:50 +0900 Subject: [PATCH 120/356] update readme --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 521e82e8..df2a612d 100644 --- a/README.md +++ b/README.md @@ -9,10 +9,8 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` - Aug 18, 2024: -Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. - +Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr! See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. Aug 17, 2024: Added a script `flux_train.py` to train FLUX.1. The script is experimental and not an optimized version. It needs >28GB VRAM for training. @@ -118,6 +116,8 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t (Combine the command into one line.) +Sample image generation during training is not tested yet. + Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizer`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. `--blockwise_fused_optimizer` enables the fusing of the optimizer for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--fused_optimizer_groups` is deprecated due to the addition of this option for FLUX.1 training. From d034032a5dff4a5ee1a108e4f1cec41d8efadab0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 19 Aug 2024 13:08:49 +0900 Subject: [PATCH 121/356] update README fix option name --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index df2a612d..9a603b28 100644 --- a/README.md +++ b/README.md @@ -105,24 +105,24 @@ Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GP ``` accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py --pretrained_model_name_or_path flux1-dev.sft --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.sft ---mixed_precision bf16 --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 +--save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 ---dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name test-bf16 +--dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name --learning_rate 5e-5 --max_train_epochs 4 --sdpa --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1 --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 ---blockwise_fused_optimizer --double_blocks_to_swap 6 --cpu_offload_checkpointing +--blockwise_fused_optimizers --double_blocks_to_swap 6 --cpu_offload_checkpointing ``` (Combine the command into one line.) Sample image generation during training is not tested yet. -Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizer`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. +Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. -`--blockwise_fused_optimizer` enables the fusing of the optimizer for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--fused_optimizer_groups` is deprecated due to the addition of this option for FLUX.1 training. +`--blockwise_fused_optimizers` enables the fusing of the optimizer for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--fused_optimizer_groups` is deprecated due to the addition of this option for FLUX.1 training. -`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--blockwise_fused_optimizer`. +`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--blockwise_fused_optimizers`. `--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. From 6e72a799c8f55f148a248693d2c0c3fb1912b04e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 19 Aug 2024 21:55:28 +0900 Subject: [PATCH 122/356] reduce peak VRAM usage by excluding some blocks to cuda --- flux_train.py | 15 +++++++++------ library/flux_models.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/flux_train.py b/flux_train.py index ecb3c7dd..b294ce42 100644 --- a/flux_train.py +++ b/flux_train.py @@ -251,7 +251,6 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - # if we load to cpu, flux.to(fp8) takes a long time flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") if args.gradient_checkpointing: @@ -259,7 +258,8 @@ def train(args): flux.requires_grad_(True) - if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None: + is_swapping_blocks = args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None + if is_swapping_blocks: # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! logger.info( @@ -412,8 +412,11 @@ def train(args): training_models = [ds_model] else: - # acceleratorがなんかよろしくやってくれるらしい - flux = accelerator.prepare(flux) + # accelerator does some magic + # if we doesn't swap blocks, we can move the model to device + flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks]) + if is_swapping_blocks: + flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする @@ -539,7 +542,7 @@ def train(args): init_kwargs=init_kwargs, ) - if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None: + if is_swapping_blocks: flux.prepare_block_swap_before_forward() # For --sample_at_first @@ -595,7 +598,7 @@ def train(args): # get noisy model input and timesteps noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype ) # pack latents and get img_ids diff --git a/library/flux_models.py b/library/flux_models.py index 3f44068f..11ef647a 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -953,6 +953,22 @@ class Flux(nn.Module): self.double_blocks_to_swap = double_blocks self.single_blocks_to_swap = single_blocks + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu + if self.double_blocks_to_swap: + save_double_blocks = self.double_blocks + self.double_blocks = None + if self.single_blocks_to_swap: + save_single_blocks = self.single_blocks + self.single_blocks = None + + self.to(device) + + if self.double_blocks_to_swap: + self.double_blocks = save_double_blocks + if self.single_blocks_to_swap: + self.single_blocks = save_single_blocks + def prepare_block_swap_before_forward(self): # move last n blocks to cpu: they are on cuda if self.double_blocks_to_swap: From 486fe8f70a53166f21f08b1c896bd9ba1e31d7e7 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 19 Aug 2024 22:30:24 +0900 Subject: [PATCH 123/356] feat: reduce memory usage and add memory efficient option for model saving --- README.md | 5 +++ flux_train.py | 6 +++ library/flux_train_utils.py | 21 ++++++++--- library/utils.py | 75 ++++++++++++++++++++++++++++++++++++- 4 files changed, 100 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 9a603b28..51e4635b 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,11 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 19, 2024: +In `flux_train.py`, the memory consumption during model saving is reduced when `--save_precision` is set to the same value as `--mixed_precision` (about 22GB). Please set the same value unless there is a reason. + +An experimental option `--mem_eff_save` is also added. When specified, it can further reduce memory consumption (about 22GB), but since it is a custom implementation, unexpected problems may occur. We do not recommend using it unless you are familiar with the code. + Aug 18, 2024: Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr! See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. diff --git a/flux_train.py b/flux_train.py index b294ce42..66996385 100644 --- a/flux_train.py +++ b/flux_train.py @@ -759,6 +759,12 @@ def setup_parser() -> argparse.ArgumentParser: add_custom_train_arguments(parser) # TODO remove this from here flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument( + "--mem_eff_save", + action="store_true", + help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う", + ) + parser.add_argument( "--fused_optimizer_groups", type=int, diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 167d61c7..3f9e8660 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -20,7 +20,7 @@ from library.device_utils import init_ipex, clean_memory_on_device init_ipex() -from .utils import setup_logging +from .utils import setup_logging, mem_eff_save_file setup_logging() import logging @@ -409,19 +409,28 @@ def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): return model_pred, weighting -def save_models(ckpt_path: str, flux: flux_models.Flux, sai_metadata: Optional[dict], save_dtype: Optional[torch.dtype] = None): +def save_models( + ckpt_path: str, + flux: flux_models.Flux, + sai_metadata: Optional[dict], + save_dtype: Optional[torch.dtype] = None, + use_mem_eff_save: bool = False, +): state_dict = {} def update_sd(prefix, sd): for k, v in sd.items(): key = prefix + k - if save_dtype is not None: + if save_dtype is not None and v.dtype != save_dtype: v = v.detach().clone().to("cpu").to(save_dtype) state_dict[key] = v update_sd("", flux.state_dict()) - save_file(state_dict, ckpt_path, metadata=sai_metadata) + if not use_mem_eff_save: + save_file(state_dict, ckpt_path, metadata=sai_metadata) + else: + mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata) def save_flux_model_on_train_end( @@ -429,7 +438,7 @@ def save_flux_model_on_train_end( ): def sd_saver(ckpt_file, epoch_no, global_step): sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") - save_models(ckpt_file, flux, sai_metadata, save_dtype) + save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save) train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) @@ -448,7 +457,7 @@ def save_flux_model_on_epoch_end_or_stepwise( ): def sd_saver(ckpt_file, epoch_no, global_step): sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") - save_models(ckpt_file, flux, sai_metadata, save_dtype) + save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save) train_util.save_sd_model_on_epoch_end_or_stepwise_common( args, diff --git a/library/utils.py b/library/utils.py index 3037c055..7de22d5a 100644 --- a/library/utils.py +++ b/library/utils.py @@ -1,9 +1,12 @@ import logging import sys import threading +from typing import * +import json +import struct + import torch from torchvision import transforms -from typing import * from diffusers import EulerAncestralDiscreteScheduler import diffusers.schedulers.scheduling_euler_ancestral_discrete from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput @@ -79,6 +82,76 @@ def setup_logging(args=None, log_level=None, reset=False): logger.info(msg_init) +def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None): + """ + memory efficient save file + """ + + _TYPES = { + torch.float64: "F64", + torch.float32: "F32", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int64: "I64", + torch.int32: "I32", + torch.int16: "I16", + torch.int8: "I8", + torch.uint8: "U8", + torch.bool: "BOOL", + getattr(torch, "float8_e5m2", None): "F8_E5M2", + getattr(torch, "float8_e4m3fn", None): "F8_E4M3", + } + _ALIGN = 256 + + def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]: + validated = {} + for key, value in metadata.items(): + if not isinstance(key, str): + raise ValueError(f"Metadata key must be a string, got {type(key)}") + if not isinstance(value, str): + print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.") + validated[key] = str(value) + else: + validated[key] = value + return validated + + print(f"Using memory efficient save file: {filename}") + + header = {} + offset = 0 + if metadata: + header["__metadata__"] = validate_metadata(metadata) + for k, v in tensors.items(): + if v.numel() == 0: # empty tensor + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]} + else: + size = v.numel() * v.element_size() + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]} + offset += size + + hjson = json.dumps(header).encode("utf-8") + hjson += b" " * (-(len(hjson) + 8) % _ALIGN) + + with open(filename, "wb") as f: + f.write(struct.pack(" Date: Tue, 20 Aug 2024 08:19:00 +0900 Subject: [PATCH 124/356] Fix debug_dataset to work --- train_network.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train_network.py b/train_network.py index 086b314a..cab0ec52 100644 --- a/train_network.py +++ b/train_network.py @@ -313,6 +313,7 @@ class NetworkTrainer: collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) if args.debug_dataset: + train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: From c62c95e8626bdb727cedc8f037c82ab3a8e66059 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 20 Aug 2024 08:21:01 +0900 Subject: [PATCH 125/356] update about multi-resolution training in FLUX.1 --- README.md | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/README.md b/README.md index 51e4635b..165eed34 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,13 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 20, 2024: +FLUX.1 supports multi-resolution inference, so training at multiple resolutions may be possible and the results may be improved (like 1024x1024, 768x768 and 512x512 ... you can use any resolution). + +The script seems to support multi-resolution even in the current version, __if `--cache_latents_to_disk` is not specified__. Please try if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. + +We will support multi-resolution caching to disk in the near future. + Aug 19, 2024: In `flux_train.py`, the memory consumption during model saving is reduced when `--save_precision` is set to the same value as `--mixed_precision` (about 22GB). Please set the same value unless there is a reason. @@ -159,6 +166,51 @@ In the case of LoRA models are trained with `bf16`, we are not sure which is bet The script can merge multiple LoRA models. If you want to merge multiple LoRA models, specify `--concat` option to work the merged LoRA model properly. +### FLUX.1 Multi-resolution training + +You can define multiple resolutions in the dataset configuration file. __Caching latents to disk is not supported yet.__ + +The dataset configuration file is like below. You can define multiple resolutions with different batch sizes. The resolutions are defined in the `[[datasets]]` section. The `[[datasets.subsets]]` section is for the dataset directory. Please specify the same directory for each resolution. + +``` +[general] +# define common settings here +flip_aug = true +color_aug = false +keep_tokens_separator= "|||" +shuffle_caption = false +caption_tag_dropout_rate = 0 +caption_extension = ".txt" + +[[datasets]] +# define the first resolution here +batch_size = 2 +enable_bucket = true +resolution = [1024, 1024] + + [[datasets.subsets]] + image_dir = "path/to/image/dir" + num_repeats = 1 + +[[datasets]] +# define the second resolution here +batch_size = 3 +enable_bucket = true +resolution = [768, 768] + + [[datasets.subsets]] + image_dir = "path/to/image/dir" + num_repeats = 1 + +[[datasets]] +# define the third resolution here +batch_size = 4 +enable_bucket = true +resolution = [512, 512] + + [[datasets.subsets]] + image_dir = "path/to/image/dir" + num_repeats = 1 ``` ## SD3 training From 6f6faf9b5a99b7f741f657a06a42f63754e450c0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 20 Aug 2024 19:16:25 +0900 Subject: [PATCH 126/356] fix to work with ai-toolkit LoRA --- networks/flux_merge_lora.py | 163 +++++++++++++++--------------------- 1 file changed, 68 insertions(+), 95 deletions(-) diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index df0ba606..1ba1f314 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -7,8 +7,6 @@ import torch from safetensors.torch import load_file, save_file from tqdm import tqdm -import lora_flux as lora_flux -from library import sai_model_spec, train_util from library.utils import setup_logging setup_logging() @@ -16,6 +14,9 @@ import logging logger = logging.getLogger(__name__) +import lora_flux as lora_flux +from library import sai_model_spec, train_util + def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -43,13 +44,11 @@ def save_to_file(file_name, state_dict, dtype, metadata): save_file(state_dict, file_name, metadata=metadata) -def merge_to_flux_model( - loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype -): +def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): logger.info(f"loading keys from FLUX.1 model: {flux_model}") flux_state_dict = load_file(flux_model, device=loading_device) - def create_key_map(n_double_layers, n_single_layers, hidden_size): + def create_key_map(n_double_layers, n_single_layers): key_map = {} for index in range(n_double_layers): prefix_from = f"transformer_blocks.{index}" @@ -60,18 +59,12 @@ def merge_to_flux_model( qkv_img = f"{prefix_to}.img_attn.qkv.{end}" qkv_txt = f"{prefix_to}.txt_attn.qkv.{end}" - key_map[f"{k}to_q.{end}"] = (qkv_img, (0, 0, hidden_size)) - key_map[f"{k}to_k.{end}"] = (qkv_img, (0, hidden_size, hidden_size)) - key_map[f"{k}to_v.{end}"] = (qkv_img, (0, hidden_size * 2, hidden_size)) - key_map[f"{k}add_q_proj.{end}"] = (qkv_txt, (0, 0, hidden_size)) - key_map[f"{k}add_k_proj.{end}"] = ( - qkv_txt, - (0, hidden_size, hidden_size), - ) - key_map[f"{k}add_v_proj.{end}"] = ( - qkv_txt, - (0, hidden_size * 2, hidden_size), - ) + key_map[f"{k}to_q.{end}"] = qkv_img + key_map[f"{k}to_k.{end}"] = qkv_img + key_map[f"{k}to_v.{end}"] = qkv_img + key_map[f"{k}add_q_proj.{end}"] = qkv_txt + key_map[f"{k}add_k_proj.{end}"] = qkv_txt + key_map[f"{k}add_v_proj.{end}"] = qkv_txt block_map = { "attn.to_out.0.weight": "img_attn.proj.weight", @@ -106,13 +99,10 @@ def merge_to_flux_model( for end in ("weight", "bias"): k = f"{prefix_from}.attn." qkv = f"{prefix_to}.linear1.{end}" - key_map[f"{k}to_q.{end}"] = (qkv, (0, 0, hidden_size)) - key_map[f"{k}to_k.{end}"] = (qkv, (0, hidden_size, hidden_size)) - key_map[f"{k}to_v.{end}"] = (qkv, (0, hidden_size * 2, hidden_size)) - key_map[f"{prefix_from}.proj_mlp.{end}"] = ( - qkv, - (0, hidden_size * 3, hidden_size * 4), - ) + key_map[f"{k}to_q.{end}"] = qkv + key_map[f"{k}to_k.{end}"] = qkv + key_map[f"{k}to_v.{end}"] = qkv + key_map[f"{prefix_from}.proj_mlp.{end}"] = qkv block_map = { "norm.linear.weight": "modulation.lin.weight", @@ -126,11 +116,14 @@ def merge_to_flux_model( for k, v in block_map.items(): key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}" + # add as-is keys + values = list([(v if isinstance(v, str) else v[0]) for v in set(key_map.values())]) + values.sort() + key_map.update({v: v for v in values}) + return key_map - key_map = create_key_map( - 18, 1, 2048 - ) # Assuming 18 double layers, 1 single layer, and hidden size of 2048 + key_map = create_key_map(18, 38) # 18 double layers, 38 single layers def find_matching_key(flux_dict, lora_key): lora_key = lora_key.replace("diffusion_model.", "") @@ -159,7 +152,6 @@ def merge_to_flux_model( "attn.add_k_proj": "txt_attn.qkv", "attn.add_v_proj": "txt_attn.qkv", } - single_block_map = { "norm.linear": "modulation.lin", "proj_out": "linear2", @@ -168,18 +160,22 @@ def merge_to_flux_model( "attn.to_q": "linear1", "attn.to_k": "linear1", "attn.to_v": "linear1", + "proj_mlp": "linear1", } + # same key exists in both single_block_map and double_block_map, so we must care about single/double + # print("lora_key before double_block_map", lora_key) for old, new in double_block_map.items(): - lora_key = lora_key.replace(old, new) - + if "double" in lora_key: + lora_key = lora_key.replace(old, new) + # print("lora_key before single_block_map", lora_key) for old, new in single_block_map.items(): - lora_key = lora_key.replace(old, new) + if "single" in lora_key: + lora_key = lora_key.replace(old, new) + # print("lora_key after mapping", lora_key) if lora_key in key_map: flux_key = key_map[lora_key] - if isinstance(flux_key, tuple): - flux_key = flux_key[0] logger.info(f"Found matching key: {flux_key}") return flux_key @@ -198,16 +194,11 @@ def merge_to_flux_model( lora_sd, _ = load_state_dict(model, merge_dtype) logger.info("merging...") - for key in tqdm(lora_sd.keys()): + for key in lora_sd.keys(): if "lora_down" in key or "lora_A" in key: - lora_name = key[ - : key.rfind(".lora_down" if "lora_down" in key else ".lora_A") - ] + lora_name = key[: key.rfind(".lora_down" if "lora_down" in key else ".lora_A")] up_key = key.replace("lora_down", "lora_up").replace("lora_A", "lora_B") - alpha_key = ( - key[: key.index("lora_down" if "lora_down" in key else "lora_A")] - + "alpha" - ) + alpha_key = key[: key.index("lora_down" if "lora_down" in key else "lora_A")] + "alpha" logger.info(f"Processing LoRA key: {lora_name}") flux_key = find_matching_key(flux_state_dict, lora_name) @@ -231,20 +222,35 @@ def merge_to_flux_model( up_weight = up_weight.to(working_device, merge_dtype) down_weight = down_weight.to(working_device, merge_dtype) + # print(up_weight.size(), down_weight.size(), weight.size()) + if lora_name.startswith("transformer."): - if "qkv" in flux_key: - hidden_size = weight.size(-1) // 3 + if "qkv" in flux_key or "linear1" in flux_key: # combined qkv or qkv+mlp update = ratio * (up_weight @ down_weight) * scale + # print(update.shape) if "img_attn" in flux_key or "txt_attn" in flux_key: - q, k, v = torch.chunk(weight, 3, dim=-1) + q, k, v = torch.chunk(weight, 3, dim=0) if "to_q" in lora_name or "add_q_proj" in lora_name: q += update.reshape(q.shape) elif "to_k" in lora_name or "add_k_proj" in lora_name: k += update.reshape(k.shape) elif "to_v" in lora_name or "add_v_proj" in lora_name: v += update.reshape(v.shape) - weight = torch.cat([q, k, v], dim=-1) + weight = torch.cat([q, k, v], dim=0) + elif "linear1" in flux_key: + q, k, v = torch.chunk(weight[: int(update.shape[-1] * 3)], 3, dim=0) + mlp = weight[int(update.shape[-1] * 3) :] + # print(q.shape, k.shape, v.shape, mlp.shape) + if "to_q" in lora_name: + q += update.reshape(q.shape) + elif "to_k" in lora_name: + k += update.reshape(k.shape) + elif "to_v" in lora_name: + v += update.reshape(v.shape) + elif "proj_mlp" in lora_name: + mlp += update.reshape(mlp.shape) + weight = torch.cat([q, k, v, mlp], dim=0) else: if len(weight.size()) == 2: weight = weight + ratio * (up_weight @ down_weight) * scale @@ -252,18 +258,11 @@ def merge_to_flux_model( weight = ( weight + ratio - * ( - up_weight.squeeze(3).squeeze(2) - @ down_weight.squeeze(3).squeeze(2) - ) - .unsqueeze(2) - .unsqueeze(3) + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale ) else: - conved = torch.nn.functional.conv2d( - down_weight.permute(1, 0, 2, 3), up_weight - ).permute(1, 0, 2, 3) + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) weight = weight + ratio * conved * scale else: if len(weight.size()) == 2: @@ -272,18 +271,11 @@ def merge_to_flux_model( weight = ( weight + ratio - * ( - up_weight.squeeze(3).squeeze(2) - @ down_weight.squeeze(3).squeeze(2) - ) - .unsqueeze(2) - .unsqueeze(3) + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale ) else: - conved = torch.nn.functional.conv2d( - down_weight.permute(1, 0, 2, 3), up_weight - ).permute(1, 0, 2, 3) + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) weight = weight + ratio * conved * scale flux_state_dict[flux_key] = weight.to(loading_device, save_dtype) @@ -308,9 +300,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_metadata is not None: if base_model is None: - base_model = lora_metadata.get( - train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None - ) + base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) # get alpha and dim alphas = {} # alpha for current model @@ -336,9 +326,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_module_name not in base_alphas: base_alphas[lora_module_name] = alpha - logger.info( - f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}" - ) + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") # merge logger.info("merging...") @@ -359,19 +347,14 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): alpha = alphas[lora_module_name] scale = math.sqrt(alpha / base_alpha) * ratio - scale = ( - abs(scale) if "lora_up" in key else scale - ) # マイナスの重みに対応する。 + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 if key in merged_sd: assert ( - merged_sd[key].size() == lora_sd[key].size() - or concat_dim is not None + merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None ), "weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。" if concat_dim is not None: - merged_sd[key] = torch.cat( - [merged_sd[key], lora_sd[key] * scale], dim=concat_dim - ) + merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) else: merged_sd[key] = merged_sd[key] + lora_sd[key] * scale else: @@ -390,9 +373,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): merged_sd[key_up] = merged_sd[key_up][:, perm] logger.info("merged model") - logger.info( - f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}" - ) + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") # check all dims are same dims_list = list(set(base_dims.values())) @@ -411,16 +392,14 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): # build minimum metadata dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" - metadata = train_util.build_minimum_network_metadata( - str(False), base_model, "networks.lora", dims, alphas, None - ) + metadata = train_util.build_minimum_network_metadata(str(False), base_model, "networks.lora", dims, alphas, None) return merged_sd, metadata def merge(args): - assert ( - len(args.models) == len(args.ratios) + assert len(args.models) == len( + args.ratios ), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" def str_to_dtype(p): @@ -456,9 +435,7 @@ def merge(args): if args.no_metadata: sai_metadata = None else: - merged_from = sai_model_spec.build_merged_from( - [args.flux_model] + args.models - ) + merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( None, @@ -477,15 +454,11 @@ def merge(args): save_to_file(args.save_to, state_dict, save_dtype, sai_metadata) else: - state_dict, metadata = merge_lora_models( - args.models, args.ratios, merge_dtype, args.concat, args.shuffle - ) + state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) logger.info("calculating hashes and creating metadata...") - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes( - state_dict, metadata - ) + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash From 9381332020b7089a41eb8d041938f8ba417528d1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 20 Aug 2024 19:32:26 +0900 Subject: [PATCH 127/356] revert merge function add add option to use new func --- README.md | 3 + networks/flux_merge_lora.py | 120 +++++++++++++++++++++++++++--------- 2 files changed, 94 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 165eed34..3f5c4daa 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,9 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 20, 2024 (update 2): +`flux_merge_lora.py` now supports LoRA from AI-toolkit (Diffusers based keys). Specify `--diffusers` option to merge LoRA with Diffusers based keys. Thanks to exveria1015! + Aug 20, 2024: FLUX.1 supports multi-resolution inference, so training at multiple resolutions may be possible and the results may be improved (like 1024x1024, 768x768 and 512x512 ... you can use any resolution). diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index 1ba1f314..fd9cc4e3 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -4,6 +4,7 @@ import os import time import torch +from safetensors import safe_open from safetensors.torch import load_file, save_file from tqdm import tqdm @@ -45,6 +46,81 @@ def save_to_file(file_name, state_dict, dtype, metadata): def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): + # create module map without loading state_dict + logger.info(f"loading keys from FLUX.1 model: {flux_model}") + lora_name_to_module_key = {} + with safe_open(flux_model, framework="pt", device=loading_device) as flux_file: + keys = list(flux_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") + lora_name_to_module_key[lora_name] = key + + flux_state_dict = load_file(flux_model, device=loading_device) + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU + + logger.info(f"merging...") + for key in tqdm(list(lora_sd.keys())): + if "lora_down" in key: + lora_name = key[: key.rfind(".lora_down")] + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + if lora_name not in lora_name_to_module_key: + logger.warning(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.") + continue + + down_weight = lora_sd.pop(key) + up_weight = lora_sd.pop(up_key) + + dim = down_weight.size()[0] + alpha = lora_sd.pop(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + module_weight_key = lora_name_to_module_key[lora_name] + if module_weight_key not in flux_state_dict: + weight = flux_file.get_tensor(module_weight_key) + else: + weight = flux_state_dict[module_weight_key] + + weight = weight.to(working_device, merge_dtype) + up_weight = up_weight.to(working_device, merge_dtype) + down_weight = down_weight.to(working_device, merge_dtype) + + # logger.info(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype) + del up_weight + del down_weight + del weight + + if len(lora_sd) > 0: + logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}") + + return flux_state_dict + + +def merge_to_flux_model_diffusers(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): logger.info(f"loading keys from FLUX.1 model: {flux_model}") flux_state_dict = load_file(flux_model, device=loading_device) @@ -422,15 +498,14 @@ def merge(args): os.makedirs(dest_dir) if args.flux_model is not None: - state_dict = merge_to_flux_model( - args.loading_device, - args.working_device, - args.flux_model, - args.models, - args.ratios, - merge_dtype, - save_dtype, - ) + if not args.diffusers: + state_dict = merge_to_flux_model( + args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + ) + else: + state_dict = merge_to_flux_model_diffusers( + args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + ) if args.no_metadata: sai_metadata = None @@ -438,16 +513,7 @@ def merge(args): merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - None, - False, - False, - False, - False, - False, - time.time(), - title=title, - merged_from=merged_from, - flux="dev", + None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev" ) logger.info(f"saving FLUX model to: {args.save_to}") @@ -466,16 +532,7 @@ def merge(args): merged_from = sai_model_spec.build_merged_from(args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - state_dict, - False, - False, - False, - True, - False, - time.time(), - title=title, - merged_from=merged_from, - flux="dev", + state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" ) metadata.update(sai_metadata) @@ -553,6 +610,11 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="shuffle lora weight./ " + "LoRAの重みをシャッフルする", ) + parser.add_argument( + "--diffusers", + action="store_true", + help="merge Diffusers (?) LoRA models / Diffusers (?) LoRAモデルをマージする", + ) return parser From dbed5126bd1133da832dae31ce73ba6c41afc9d3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 20 Aug 2024 19:33:47 +0900 Subject: [PATCH 128/356] chore: formatting --- networks/flux_merge_lora.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index fd9cc4e3..d5e82920 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -113,7 +113,7 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati del up_weight del down_weight del weight - + if len(lora_sd) > 0: logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}") @@ -587,12 +587,7 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="LoRA models to merge: safetensors file / マージするLoRAモデル、safetensorsファイル", ) - parser.add_argument( - "--ratios", - type=float, - nargs="*", - help="ratios for each model / それぞれのLoRAモデルの比率", - ) + parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") parser.add_argument( "--no_metadata", action="store_true", From 6ab48b09d8e46973d5e5fa47baeae3a464d06d04 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 20 Aug 2024 21:39:43 +0900 Subject: [PATCH 129/356] feat: Support multi-resolution training with caching latents to disk --- README.md | 11 +++- library/strategy_base.py | 112 ++++++++++++++++++++++++++------------- library/strategy_flux.py | 11 +++- library/train_util.py | 2 +- 4 files changed, 93 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 3f5c4daa..1d44c9e5 100644 --- a/README.md +++ b/README.md @@ -9,13 +9,20 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 20, 2024 (update 3): +__Experimental__ The multi-resolution training is now supported with caching latents to disk. + +The cache files now hold latents for multiple resolutions. Since the latents are appended to the current cache file, it is recommended to delete the cache file in advance (if not, the old latents is kept in .npz file). + +See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. + Aug 20, 2024 (update 2): `flux_merge_lora.py` now supports LoRA from AI-toolkit (Diffusers based keys). Specify `--diffusers` option to merge LoRA with Diffusers based keys. Thanks to exveria1015! Aug 20, 2024: FLUX.1 supports multi-resolution inference, so training at multiple resolutions may be possible and the results may be improved (like 1024x1024, 768x768 and 512x512 ... you can use any resolution). -The script seems to support multi-resolution even in the current version, __if `--cache_latents_to_disk` is not specified__. Please try if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. +The script seems to support multi-resolution even in the current version, ~~if `--cache_latents_to_disk` is not specified~~ -> `--cache_latents_to_disk` is now supported for multi-resolution training. Please try if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. We will support multi-resolution caching to disk in the near future. @@ -171,7 +178,7 @@ The script can merge multiple LoRA models. If you want to merge multiple LoRA mo ### FLUX.1 Multi-resolution training -You can define multiple resolutions in the dataset configuration file. __Caching latents to disk is not supported yet.__ +You can define multiple resolutions in the dataset configuration file. The dataset configuration file is like below. You can define multiple resolutions with different batch sizes. The resolutions are defined in the `[[datasets]]` section. The `[[datasets.subsets]]` section is for the dataset directory. Please specify the same directory for each resolution. diff --git a/library/strategy_base.py b/library/strategy_base.py index a99a0829..e7d3a97e 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -219,7 +219,13 @@ class LatentsCachingStrategy: raise NotImplementedError def _default_is_disk_cached_latents_expected( - self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + self, + latents_stride: int, + bucket_reso: Tuple[int, int], + npz_path: str, + flip_aug: bool, + alpha_mask: bool, + multi_resolution: bool = False, ): if not self.cache_to_disk: return False @@ -230,25 +236,17 @@ class LatentsCachingStrategy: expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) + # e.g. "_32x64", HxW + key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else "" + try: npz = np.load(npz_path) - if npz["latents"].shape[1:3] != expected_latents_size: + if "latents" + key_reso_suffix not in npz: + return False + if flip_aug and "latents_flipped" + key_reso_suffix not in npz: + return False + if alpha_mask and "alpha_mask" + key_reso_suffix not in npz: return False - - if flip_aug: - if "latents_flipped" not in npz: - return False - if npz["latents_flipped"].shape[1:3] != expected_latents_size: - return False - - if alpha_mask: - if "alpha_mask" not in npz: - return False - if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]): - return False - else: - if "alpha_mask" in npz: - return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -257,7 +255,15 @@ class LatentsCachingStrategy: # TODO remove circular dependency for ImageInfo def _default_cache_batch_latents( - self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool + self, + encode_by_vae, + vae_device, + vae_dtype, + image_infos: List, + flip_aug: bool, + alpha_mask: bool, + random_crop: bool, + multi_resolution: bool = False, ): """ Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common. @@ -287,8 +293,13 @@ class LatentsCachingStrategy: original_size = original_sizes[i] crop_ltrb = crop_ltrbs[i] + latents_size = latents.shape[1:3] # H, W + key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW + if self.cache_to_disk: - self.save_latents_to_disk(info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask) + self.save_latents_to_disk( + info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix + ) else: info.latents_original_size = original_size info.latents_crop_ltrb = crop_ltrb @@ -298,31 +309,56 @@ class LatentsCachingStrategy: info.alpha_mask = alpha_mask def load_latents_from_disk( - self, npz_path: str + self, npz_path: str, bucket_reso: Tuple[int, int] ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: - npz = np.load(npz_path) - if "latents" not in npz: - raise ValueError(f"error: npz is old format. please re-generate {npz_path}") + """ + for SD/SDXL/SD3.0 + """ + return self._default_load_latents_from_disk(None, npz_path, bucket_reso) - latents = npz["latents"] - original_size = npz["original_size"].tolist() - crop_ltrb = npz["crop_ltrb"].tolist() - flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None - alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None + def _default_load_latents_from_disk( + self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + if latents_stride is None: + key_reso_suffix = "" + else: + latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) + key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW + + npz = np.load(npz_path) + if "latents" + key_reso_suffix not in npz: + raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}") + + latents = npz["latents" + key_reso_suffix] + original_size = npz["original_size" + key_reso_suffix].tolist() + crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist() + flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None + alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None return latents, original_size, crop_ltrb, flipped_latents, alpha_mask def save_latents_to_disk( - self, npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None + self, + npz_path, + latents_tensor, + original_size, + crop_ltrb, + flipped_latents_tensor=None, + alpha_mask=None, + key_reso_suffix="", ): kwargs = {} + + if os.path.exists(npz_path): + # load existing npz and update it + npz = np.load(npz_path) + for key in npz.files: + kwargs[key] = npz[key] + + kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy() + kwargs["original_size" + key_reso_suffix] = np.array(original_size) + kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb) if flipped_latents_tensor is not None: - kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() + kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy() if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() - np.savez( - npz_path, - latents=latents_tensor.float().cpu().numpy(), - original_size=np.array(original_size), - crop_ltrb=np.array(crop_ltrb), - **kwargs, - ) + kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy() + np.savez(npz_path, **kwargs) diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 3880a1e1..5c620f3d 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -200,7 +200,12 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy): ) def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, True) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): @@ -208,7 +213,9 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy): vae_device = vae.device vae_dtype = vae.dtype - self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + self._default_cache_batch_latents( + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, True + ) if not train_util.HIGH_VRAM: train_util.clean_memory_on_device(vae.device) diff --git a/library/train_util.py b/library/train_util.py index f4ac8740..8929c192 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1381,7 +1381,7 @@ class BaseDataset(torch.utils.data.Dataset): image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 latents, original_size, crop_ltrb, flipped_latents, alpha_mask = ( - self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz) + self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso) ) if flipped: latents = flipped_latents From 7e459c00b2e142e40a9452341934c2eb9f70a172 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 21 Aug 2024 08:02:33 +0900 Subject: [PATCH 130/356] Update T5 attention mask handling in FLUX --- README.md | 3 +++ flux_minimal_inference.py | 33 +++++++++++++++++++----- flux_train.py | 6 ++++- flux_train_network.py | 13 +++++----- library/flux_models.py | 51 +++++++++++++++++++++---------------- library/flux_train_utils.py | 20 ++++++++++++--- library/strategy_flux.py | 25 ++++++++++-------- 7 files changed, 101 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index 1d44c9e5..43edbbed 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,9 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 21, 2024: +The specification of `--apply_t5_attn_mask` has been changed. Previously, the T5 output was zero-padded, but now, two steps are taken: "1. Apply mask when encoding T5" and "2. Apply mask in the attention of Double Block". Fine tuning, LoRA training, and inference in `flux_mini_inference.py` have been changed. + Aug 20, 2024 (update 3): __Experimental__ The multi-resolution training is now supported with caching latents to disk. diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index b09f6380..5b8aa250 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -70,12 +70,22 @@ def denoise( vec: torch.Tensor, timesteps: list[float], guidance: float = 4.0, + t5_attn_mask: Optional[torch.Tensor] = None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) - pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) img = img + (t_prev - t_curr) * pred @@ -92,6 +102,7 @@ def do_sample( txt_ids: torch.Tensor, num_steps: int, guidance: float, + t5_attn_mask: Optional[torch.Tensor], is_schnell: bool, device: torch.device, flux_dtype: torch.dtype, @@ -101,10 +112,14 @@ def do_sample( # denoise initial noise if accelerator: with accelerator.autocast(), torch.no_grad(): - x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance) + x = denoise( + model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + ) else: with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad(): - x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance) + x = denoise( + model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + ) return x @@ -156,14 +171,14 @@ def generate_image( clip_l.to(clip_l_dtype) t5xxl.to(t5xxl_dtype) with accelerator.autocast(): - _, t5_out, txt_ids = encoding_strategy.encode_tokens( + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask ) else: with torch.autocast(device_type=device.type, dtype=clip_l_dtype): - l_pooled, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): - _, t5_out, txt_ids = encoding_strategy.encode_tokens( + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask ) @@ -186,7 +201,11 @@ def generate_image( steps = 4 if is_schnell else 50 img_ids = img_ids.to(device) - x = do_sample(accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, is_schnell, device, flux_dtype) + t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None + + x = do_sample( + accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, t5_attn_mask, is_schnell, device, flux_dtype + ) if args.offload: model = model.cpu() # del model diff --git a/flux_train.py b/flux_train.py index 66996385..ecb8a108 100644 --- a/flux_train.py +++ b/flux_train.py @@ -610,7 +610,10 @@ def train(args): guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device) # call model - l_pooled, t5_out, txt_ids = text_encoder_conds + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + with accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = flux( @@ -621,6 +624,7 @@ def train(args): y=l_pooled, timesteps=timesteps / 1000, guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, ) # unpack latents diff --git a/flux_train_network.py b/flux_train_network.py index 002252c8..49bd270c 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -233,11 +233,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): self.flux_lower = flux_lower self.target_device = device - def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None): + def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): self.flux_lower.to("cpu") clean_memory_on_device(self.target_device) self.flux_upper.to(self.target_device) - img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance) + img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask) self.flux_upper.to("cpu") clean_memory_on_device(self.target_device) self.flux_lower.to(self.target_device) @@ -300,10 +300,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): guidance_vec.requires_grad_(True) # Predict the noise residual - l_pooled, t5_out, txt_ids = text_encoder_conds - # print( - # f"model_input: {noisy_model_input.shape}, img_ids: {img_ids.shape}, t5_out: {t5_out.shape}, txt_ids: {txt_ids.shape}, l_pooled: {l_pooled.shape}, timesteps: {timesteps.shape}, guidance_vec: {guidance_vec.shape}" - # ) + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None if not args.split_mode: # normal forward @@ -317,6 +316,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): y=l_pooled, timesteps=timesteps / 1000, guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, ) else: # split forward to reduce memory usage @@ -337,6 +337,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): y=l_pooled, timesteps=timesteps / 1000, guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, ) # move flux upper back to cpu, and then move flux lower to gpu diff --git a/library/flux_models.py b/library/flux_models.py index 11ef647a..6f28da60 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -440,10 +440,10 @@ configs = { # region math -def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor: q, k = apply_rope(q, k, pe) - x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) x = rearrange(x, "B H L D -> B L (H D)") return x @@ -607,11 +607,7 @@ class SelfAttention(nn.Module): self.norm = QKNorm(head_dim) self.proj = nn.Linear(dim, dim) - # self.gradient_checkpointing = False - - # def enable_gradient_checkpointing(self): - # self.gradient_checkpointing = True - + # this is not called from DoubleStreamBlock/SingleStreamBlock because they uses attention function directly def forward(self, x: Tensor, pe: Tensor) -> Tensor: qkv = self.qkv(x) q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) @@ -620,12 +616,6 @@ class SelfAttention(nn.Module): x = self.proj(x) return x - # def forward(self, *args, **kwargs): - # if self.training and self.gradient_checkpointing: - # return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) - # else: - # return self._forward(*args, **kwargs) - @dataclass class ModulationOut: @@ -690,7 +680,9 @@ class DoubleStreamBlock(nn.Module): self.gradient_checkpointing = False self.cpu_offload_checkpointing = False - def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + def _forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: img_mod1, img_mod2 = self.img_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec) @@ -713,7 +705,18 @@ class DoubleStreamBlock(nn.Module): k = torch.cat((txt_k, img_k), dim=2) v = torch.cat((txt_v, img_v), dim=2) - attn = attention(q, k, v, pe=pe) + # make attention mask if not None + attn_mask = None + if txt_attention_mask is not None: + attn_mask = txt_attention_mask # b, seq_len + attn_mask = torch.cat( + (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1]).to(attn_mask.device)), dim=1 + ) # b, seq_len + img_len + + # broadcast attn_mask to all heads + attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1) + + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] # calculate the img blocks @@ -725,10 +728,12 @@ class DoubleStreamBlock(nn.Module): txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) return img, txt - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + def forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: if self.training and self.gradient_checkpointing: if not self.cpu_offload_checkpointing: - return checkpoint(self._forward, img, txt, vec, pe, use_reentrant=False) + return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False) # cpu offload checkpointing def create_custom_forward(func): @@ -739,10 +744,10 @@ class DoubleStreamBlock(nn.Module): return custom_forward - return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe) + return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask) else: - return self._forward(img, txt, vec, pe) + return self._forward(img, txt, vec, pe, txt_attention_mask) # def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): # if self.training and self.gradient_checkpointing: @@ -992,6 +997,7 @@ class Flux(nn.Module): timesteps: Tensor, y: Tensor, guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -1011,7 +1017,7 @@ class Flux(nn.Module): if not self.double_blocks_to_swap: for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning for block_idx in range(self.double_blocks_to_swap): @@ -1033,7 +1039,7 @@ class Flux(nn.Module): block.to(self.device) # move to cuda # print(f"Moved double block {block_idx} to cuda.") - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) if moving: self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) @@ -1164,6 +1170,7 @@ class FluxUpper(nn.Module): timesteps: Tensor, y: Tensor, guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -1182,7 +1189,7 @@ class FluxUpper(nn.Module): pe = self.pe_embedder(ids) for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) return img, txt, vec, pe diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 3f9e8660..1d3f80d7 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -190,9 +190,10 @@ def sample_image_inference( te_outputs = sample_prompts_te_outputs[prompt] else: tokens_and_masks = tokenize_strategy.tokenize(prompt) + # strategy has apply_t5_attn_mask option te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) - l_pooled, t5_out, txt_ids = te_outputs + l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs # sample image weight_dtype = ae.dtype # TOFO give dtype as argument @@ -208,9 +209,10 @@ def sample_image_inference( ) timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) + t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None with accelerator.autocast(), torch.no_grad(): - x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale) + x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask) x = x.float() x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) @@ -289,12 +291,22 @@ def denoise( vec: torch.Tensor, timesteps: list[float], guidance: float = 4.0, + t5_attn_mask: Optional[torch.Tensor] = None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) - pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) img = img + (t_prev - t_curr) * pred @@ -498,7 +510,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--apply_t5_attn_mask", action="store_true", - help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", + help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する", ) parser.add_argument( "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 5c620f3d..737af390 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -64,22 +64,25 @@ class FluxTextEncodingStrategy(TextEncodingStrategy): l_tokens, t5_tokens = tokens[:2] t5_attn_mask = tokens[2] if len(tokens) > 2 else None + # clip_l is None when using T5 only if clip_l is not None and l_tokens is not None: l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"] else: l_pooled = None + # t5xxl is None when using CLIP only if t5xxl is not None and t5_tokens is not None: # t5_out is [b, max length, 4096] - t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True) - if apply_t5_attn_mask: - t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) + attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device) + t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True) + # if zero_pad_t5_output: + # t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device) else: t5_out = None txt_ids = None - return [l_pooled, t5_out, txt_ids] + return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): @@ -115,6 +118,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): return False if "txt_ids" not in npz: return False + if "t5_attn_mask" not in npz: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -129,12 +134,12 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): l_pooled = data["l_pooled"] t5_out = data["t5_out"] txt_ids = data["txt_ids"] + t5_attn_mask = data["t5_attn_mask"] if self.apply_t5_attn_mask: - t5_attn_mask = data["t5_attn_mask"] t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) - return [l_pooled, t5_out, txt_ids] + return [l_pooled, t5_out, txt_ids, t5_attn_mask] def cache_batch_outputs( self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List @@ -145,7 +150,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): # attn_mask is not applied when caching to disk: it is applied when loading from disk - l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens( + l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens( tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk ) @@ -159,15 +164,15 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): l_pooled = l_pooled.cpu().numpy() t5_out = t5_out.cpu().numpy() txt_ids = txt_ids.cpu().numpy() + t5_attn_mask = tokens_and_masks[2].cpu().numpy() for i, info in enumerate(infos): l_pooled_i = l_pooled[i] t5_out_i = t5_out[i] txt_ids_i = txt_ids[i] + t5_attn_mask_i = t5_attn_mask[i] if self.cache_to_disk: - t5_attn_mask = tokens_and_masks[2] - t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() np.savez( info.text_encoder_outputs_npz, l_pooled=l_pooled_i, @@ -176,7 +181,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): t5_attn_mask=t5_attn_mask_i, ) else: - info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i) + info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i) class FluxLatentsCachingStrategy(LatentsCachingStrategy): From e17c42cb0de8a1303a607ecc75af092dc12dc272 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 21 Aug 2024 12:28:45 +0900 Subject: [PATCH 131/356] Add BFL/Diffusers LoRA converter #1467 #1458 #1483 --- networks/convert_flux_lora.py | 403 ++++++++++++++++++++++++++++++++++ 1 file changed, 403 insertions(+) create mode 100644 networks/convert_flux_lora.py diff --git a/networks/convert_flux_lora.py b/networks/convert_flux_lora.py new file mode 100644 index 00000000..dd962ebf --- /dev/null +++ b/networks/convert_flux_lora.py @@ -0,0 +1,403 @@ +# convert key mapping and data format from some LoRA format to another +""" +Original LoRA format: Based on Black Forest Labs, QKV and MLP are unified into one module +alpha is scalar for each LoRA module + +0 to 18 +lora_unet_double_blocks_0_img_attn_proj.alpha torch.Size([]) +lora_unet_double_blocks_0_img_attn_proj.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_attn_proj.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_img_attn_qkv.alpha torch.Size([]) +lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight torch.Size([9216, 4]) +lora_unet_double_blocks_0_img_mlp_0.alpha torch.Size([]) +lora_unet_double_blocks_0_img_mlp_0.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_mlp_0.lora_up.weight torch.Size([12288, 4]) +lora_unet_double_blocks_0_img_mlp_2.alpha torch.Size([]) +lora_unet_double_blocks_0_img_mlp_2.lora_down.weight torch.Size([4, 12288]) +lora_unet_double_blocks_0_img_mlp_2.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_img_mod_lin.alpha torch.Size([]) +lora_unet_double_blocks_0_img_mod_lin.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_mod_lin.lora_up.weight torch.Size([18432, 4]) +lora_unet_double_blocks_0_txt_attn_proj.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_attn_proj.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_attn_proj.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_txt_attn_qkv.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_attn_qkv.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_attn_qkv.lora_up.weight torch.Size([9216, 4]) +lora_unet_double_blocks_0_txt_mlp_0.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_mlp_0.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_mlp_0.lora_up.weight torch.Size([12288, 4]) +lora_unet_double_blocks_0_txt_mlp_2.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_mlp_2.lora_down.weight torch.Size([4, 12288]) +lora_unet_double_blocks_0_txt_mlp_2.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_txt_mod_lin.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_mod_lin.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_mod_lin.lora_up.weight torch.Size([18432, 4]) + +0 to 37 +lora_unet_single_blocks_0_linear1.alpha torch.Size([]) +lora_unet_single_blocks_0_linear1.lora_down.weight torch.Size([4, 3072]) +lora_unet_single_blocks_0_linear1.lora_up.weight torch.Size([21504, 4]) +lora_unet_single_blocks_0_linear2.alpha torch.Size([]) +lora_unet_single_blocks_0_linear2.lora_down.weight torch.Size([4, 15360]) +lora_unet_single_blocks_0_linear2.lora_up.weight torch.Size([3072, 4]) +lora_unet_single_blocks_0_modulation_lin.alpha torch.Size([]) +lora_unet_single_blocks_0_modulation_lin.lora_down.weight torch.Size([4, 3072]) +lora_unet_single_blocks_0_modulation_lin.lora_up.weight torch.Size([9216, 4]) +""" +""" +ai-toolkit: Based on Diffusers, QKV and MLP are separated into 3 modules. +A is down, B is up. No alpha for each LoRA module. + +0 to 18 +transformer.transformer_blocks.0.attn.add_k_proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.add_k_proj.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.add_v_proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.add_v_proj.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_add_out.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_add_out.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.ff.net.0.proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.ff.net.0.proj.lora_B.weight torch.Size([12288, 16]) +transformer.transformer_blocks.0.ff.net.2.lora_A.weight torch.Size([16, 12288]) +transformer.transformer_blocks.0.ff.net.2.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.ff_context.net.0.proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.ff_context.net.0.proj.lora_B.weight torch.Size([12288, 16]) +transformer.transformer_blocks.0.ff_context.net.2.lora_A.weight torch.Size([16, 12288]) +transformer.transformer_blocks.0.ff_context.net.2.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.norm1.linear.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.norm1.linear.lora_B.weight torch.Size([18432, 16]) +transformer.transformer_blocks.0.norm1_context.linear.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.norm1_context.linear.lora_B.weight torch.Size([18432, 16]) + +0 to 37 +transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([3072, 16]) +transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([3072, 16]) +transformer.single_transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([3072, 16]) +transformer.single_transformer_blocks.0.norm.linear.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.norm.linear.lora_B.weight torch.Size([9216, 16]) +transformer.single_transformer_blocks.0.proj_mlp.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.proj_mlp.lora_B.weight torch.Size([12288, 16]) +transformer.single_transformer_blocks.0.proj_out.lora_A.weight torch.Size([16, 15360]) +transformer.single_transformer_blocks.0.proj_out.lora_B.weight torch.Size([3072, 16]) +""" +""" +xlabs: Unknown format. +0 to 18 +double_blocks.0.processor.proj_lora1.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.proj_lora1.up.weight torch.Size([3072, 16]) +double_blocks.0.processor.proj_lora2.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.proj_lora2.up.weight torch.Size([3072, 16]) +double_blocks.0.processor.qkv_lora1.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.qkv_lora1.up.weight torch.Size([9216, 16]) +double_blocks.0.processor.qkv_lora2.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.qkv_lora2.up.weight torch.Size([9216, 16]) +""" + + +import argparse +from safetensors.torch import save_file +from safetensors import safe_open +import torch + + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def convert_to_sd_scripts(sds_sd, ait_sd, sds_key, ait_key): + ait_down_key = ait_key + ".lora_A.weight" + if ait_down_key not in ait_sd: + return + ait_up_key = ait_key + ".lora_B.weight" + + down_weight = ait_sd.pop(ait_down_key) + sds_sd[sds_key + ".lora_down.weight"] = down_weight + sds_sd[sds_key + ".lora_up.weight"] = ait_sd.pop(ait_up_key) + rank = down_weight.shape[0] + sds_sd[sds_key + ".alpha"] = torch.scalar_tensor(rank, dtype=down_weight.dtype, device=down_weight.device) + + +def convert_to_sd_scripts_cat(sds_sd, ait_sd, sds_key, ait_keys): + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + if ait_down_keys[0] not in ait_sd: + return + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + + down_weights = [ait_sd.pop(k) for k in ait_down_keys] + up_weights = [ait_sd.pop(k) for k in ait_up_keys] + + # lora_down is concatenated along dim=0, so rank is multiplied by the number of splits + rank = down_weights[0].shape[0] + num_splits = len(ait_keys) + sds_sd[sds_key + ".lora_down.weight"] = torch.cat(down_weights, dim=0) + + merged_up_weights = torch.zeros( + (sum(w.shape[0] for w in up_weights), rank * num_splits), + dtype=up_weights[0].dtype, + device=up_weights[0].device, + ) + + i = 0 + for j, up_weight in enumerate(up_weights): + merged_up_weights[i : i + up_weight.shape[0], j * rank : (j + 1) * rank] = up_weight + i += up_weight.shape[0] + + sds_sd[sds_key + ".lora_up.weight"] = merged_up_weights + + # set alpha to new_rank + new_rank = rank * num_splits + sds_sd[sds_key + ".alpha"] = torch.scalar_tensor(new_rank, dtype=down_weights[0].dtype, device=down_weights[0].device) + + +def convert_ai_toolkit_to_sd_scripts(ait_sd): + sds_sd = {} + for i in range(19): + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_out.0" + ) + convert_to_sd_scripts_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.to_q", + f"transformer.transformer_blocks.{i}.attn.to_k", + f"transformer.transformer_blocks.{i}.attn.to_v", + ], + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_0", f"transformer.transformer_blocks.{i}.ff.net.0.proj" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_2", f"transformer.transformer_blocks.{i}.ff.net.2" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mod_lin", f"transformer.transformer_blocks.{i}.norm1.linear" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_add_out" + ) + convert_to_sd_scripts_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.add_q_proj", + f"transformer.transformer_blocks.{i}.attn.add_k_proj", + f"transformer.transformer_blocks.{i}.attn.add_v_proj", + ], + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_0", f"transformer.transformer_blocks.{i}.ff_context.net.0.proj" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_2", f"transformer.transformer_blocks.{i}.ff_context.net.2" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mod_lin", f"transformer.transformer_blocks.{i}.norm1_context.linear" + ) + + for i in range(38): + convert_to_sd_scripts_cat( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear1", + [ + f"transformer.single_transformer_blocks.{i}.attn.to_q", + f"transformer.single_transformer_blocks.{i}.attn.to_k", + f"transformer.single_transformer_blocks.{i}.attn.to_v", + f"transformer.single_transformer_blocks.{i}.proj_mlp", + ], + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_linear2", f"transformer.single_transformer_blocks.{i}.proj_out" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_modulation_lin", f"transformer.single_transformer_blocks.{i}.norm.linear" + ) + + if len(ait_sd) > 0: + logger.warning(f"Unsuppored keys for sd-scripts: {ait_sd.keys()}") + return sds_sd + + +def convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + + # scale weight by alpha and dim + rank = down_weight.shape[0] + alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar + scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here + print(f"rank: {rank}, alpha: {alpha}, scale: {scale}") + + # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2 + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + # print(f"scale: {scale}, scale_down: {scale_down}, scale_up: {scale_up}") + + ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down + ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up + + +def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + + # scale weight by alpha and dim + rank = down_weight.shape[0] + alpha = sds_sd.pop(sds_key + ".alpha") + scale = alpha / rank + + # calculate scale_down and scale_up + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + + num_splits = len(ait_keys) + up_weight = sds_sd.pop(sds_key + ".lora_up.weight") + + # down_weight is copied to each split + ait_sd.update({k: down_weight * scale_down for k in ait_down_keys}) + + # calculate dims if not provided + if dims is None: + dims = [up_weight.shape[0] // num_splits] * num_splits + else: + assert sum(dims) == up_weight.shape[0] + + # up_weight is split to each split + ait_sd.update({k: v * scale_up for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) + + +def convert_sd_scripts_to_ai_toolkit(sds_sd): + ait_sd = {} + for i in range(19): + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_out.0" + ) + convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.to_q", + f"transformer.transformer_blocks.{i}.attn.to_k", + f"transformer.transformer_blocks.{i}.attn.to_v", + ], + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_0", f"transformer.transformer_blocks.{i}.ff.net.0.proj" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_2", f"transformer.transformer_blocks.{i}.ff.net.2" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mod_lin", f"transformer.transformer_blocks.{i}.norm1.linear" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_add_out" + ) + convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.add_q_proj", + f"transformer.transformer_blocks.{i}.attn.add_k_proj", + f"transformer.transformer_blocks.{i}.attn.add_v_proj", + ], + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_0", f"transformer.transformer_blocks.{i}.ff_context.net.0.proj" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_2", f"transformer.transformer_blocks.{i}.ff_context.net.2" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mod_lin", f"transformer.transformer_blocks.{i}.norm1_context.linear" + ) + + for i in range(38): + convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear1", + [ + f"transformer.single_transformer_blocks.{i}.attn.to_q", + f"transformer.single_transformer_blocks.{i}.attn.to_k", + f"transformer.single_transformer_blocks.{i}.attn.to_v", + f"transformer.single_transformer_blocks.{i}.proj_mlp", + ], + dims=[3072, 3072, 3072, 12288], + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_linear2", f"transformer.single_transformer_blocks.{i}.proj_out" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_modulation_lin", f"transformer.single_transformer_blocks.{i}.norm.linear" + ) + + if len(sds_sd) > 0: + logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}") + return ait_sd + + +def main(args): + # load source safetensors + logger.info(f"Loading source file {args.src_path}") + state_dict = {} + with safe_open(args.src_path, framework="pt") as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + + logger.info(f"Converting {args.src} to {args.dst} format") + if args.src == "ai-toolkit" and args.dst == "sd-scripts": + state_dict = convert_ai_toolkit_to_sd_scripts(state_dict) + elif args.src == "sd-scripts" and args.dst == "ai-toolkit": + state_dict = convert_sd_scripts_to_ai_toolkit(state_dict) + else: + raise NotImplementedError(f"Conversion from {args.src} to {args.dst} is not supported") + + # save destination safetensors + logger.info(f"Saving destination file {args.dst_path}") + save_file(state_dict, args.dst_path, metadata=metadata) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert LoRA format") + parser.add_argument("--src", type=str, default="ai-toolkit", help="source format, ai-toolkit or sd-scripts") + parser.add_argument("--dst", type=str, default="sd-scripts", help="destination format, ai-toolkit or sd-scripts") + parser.add_argument("--src_path", type=str, default=None, help="source path") + parser.add_argument("--dst_path", type=str, default=None, help="destination path") + args = parser.parse_args() + main(args) From 2b07a92c8d970a8538a47dd1bcad3122da4e195a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 21 Aug 2024 12:30:23 +0900 Subject: [PATCH 132/356] Fix error in applying mask in Attention and add LoRA converter script --- README.md | 6 ++++++ library/flux_models.py | 5 +++-- networks/convert_flux_lora.py | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 43edbbed..f4056851 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,12 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 21, 2024 (update 2): +Fixed an error in applying mask in Attention. The attention mask was float, but it should be bool. + +Added a script `convert_flux_lora.py` to convert LoRA between sd-scripts format (BFL-based) and AI-toolkit format (Diffusers-based). See `--help` for details. BFL-based LoRA has a large module, so converting it to Diffusers format may reduce temporary memory usage in the inference environment. Note that re-conversion will increase the size of LoRA. + + Aug 21, 2024: The specification of `--apply_t5_attn_mask` has been changed. Previously, the T5 output was zero-padded, but now, two steps are taken: "1. Apply mask when encoding T5" and "2. Apply mask in the attention of Double Block". Fine tuning, LoRA training, and inference in `flux_mini_inference.py` have been changed. diff --git a/library/flux_models.py b/library/flux_models.py index 6f28da60..e38119cd 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -708,9 +708,10 @@ class DoubleStreamBlock(nn.Module): # make attention mask if not None attn_mask = None if txt_attention_mask is not None: - attn_mask = txt_attention_mask # b, seq_len + # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask + attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len attn_mask = torch.cat( - (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1]).to(attn_mask.device)), dim=1 + (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1 ) # b, seq_len + img_len # broadcast attn_mask to all heads diff --git a/networks/convert_flux_lora.py b/networks/convert_flux_lora.py index dd962ebf..e9743534 100644 --- a/networks/convert_flux_lora.py +++ b/networks/convert_flux_lora.py @@ -248,7 +248,7 @@ def convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): rank = down_weight.shape[0] alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here - print(f"rank: {rank}, alpha: {alpha}, scale: {scale}") + # print(f"rank: {rank}, alpha: {alpha}, scale: {scale}") # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2 scale_down = scale From e1cd19c0c0ef55709e8eb1e5babe25045f65031f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 21 Aug 2024 21:04:10 +0900 Subject: [PATCH 133/356] add stochastic rounding, fix single block --- README.md | 19 ++++++-- flux_train.py | 93 ++++++++++++++++++++++++++++++++++---- library/adafactor_fused.py | 36 ++++++++++++++- library/flux_models.py | 1 + 4 files changed, 135 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index f4056851..45349ba3 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,15 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 21, 2024 (update 3): +- There is a bug that `--full_bf16` option is enabled even if it is not specified in `flux_train.py`. The bug will be fixed sooner. __Please specify the `--full_bf16` option explicitly, especially when training with 24GB VRAM.__ +- Stochastic rounding is now implemented when `--fused_backward_pass` is specified. The implementation is +based on the code provided by 2kpr. Thank you so much! + - With this change, `--fused_backward_pass` is recommended over `--blockwise_fused_optimizers` when `--full_bf16` is specified. + - Please note that `--fused_backward_pass` is only supported with Adafactor. +- The sample command in [FLUX.1 fine-tuning](#flux1-fine-tuning) is updated to reflect these changes. +- Fixed `--single_blocks_to_swap` is not working in `flux_train.py`. + Aug 21, 2024 (update 2): Fixed an error in applying mask in Attention. The attention mask was float, but it should be bool. @@ -142,7 +151,7 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t --learning_rate 5e-5 --max_train_epochs 4 --sdpa --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1 --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 ---blockwise_fused_optimizers --double_blocks_to_swap 6 --cpu_offload_checkpointing +--fused_backward_pass --double_blocks_to_swap 6 --cpu_offload_checkpointing --full_bf16 ``` (Combine the command into one line.) @@ -151,9 +160,13 @@ Sample image generation during training is not tested yet. Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. -`--blockwise_fused_optimizers` enables the fusing of the optimizer for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--fused_optimizer_groups` is deprecated due to the addition of this option for FLUX.1 training. +`--full_bf16` enables the training with bf16 (weights and gradients). -`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--blockwise_fused_optimizers`. +`--fused_backward_pass` enables the fusing of the optimizer step into the backward pass for each parameter. This reduces the memory usage during training. Only Adafactor optimizer is supported for now. Stochastic rounding is also enabled when `--fused_backward_pass` and `--full_bf16` are specified. + +`--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. + +`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks. `--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. diff --git a/flux_train.py b/flux_train.py index ecb8a108..bcf4b956 100644 --- a/flux_train.py +++ b/flux_train.py @@ -277,7 +277,10 @@ def train(args): training_models = [] params_to_optimize = [] training_models.append(flux) - params_to_optimize.append({"params": list(flux.parameters()), "lr": args.learning_rate}) + name_and_params = list(flux.named_parameters()) + # single param group for now + params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate}) + param_names = [[n for n, _ in name_and_params]] # calculate number of trainable parameters n_params = 0 @@ -433,17 +436,89 @@ def train(args): import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) - for param_group in optimizer.param_groups: - for parameter in param_group["params"]: + + double_blocks_to_swap = args.double_blocks_to_swap + single_blocks_to_swap = args.single_blocks_to_swap + num_double_blocks = len(flux.double_blocks) + num_single_blocks = len(flux.single_blocks) + handled_double_block_indices = set() + handled_single_block_indices = set() + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: + grad_hook = None - def __grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None + if double_blocks_to_swap: + if param_name.startswith("double_blocks"): + block_idx = int(param_name.split(".")[1]) + if ( + block_idx not in handled_double_block_indices + and block_idx >= (num_double_blocks - double_blocks_to_swap) - 1 + and block_idx < num_double_blocks - 1 + ): + # swap next (already backpropagated) block + handled_double_block_indices.add(block_idx) + block_idx_cpu = block_idx + 1 + block_idx_cuda = double_blocks_to_swap - (num_double_blocks - block_idx_cpu) - parameter.register_post_accumulate_grad_hook(__grad_hook) + # create swap hook + def create_double_swap_grad_hook(bidx, bidx_cuda): + def __grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + # swap blocks if necessary + flux.double_blocks[bidx].to("cpu") + flux.double_blocks[bidx_cuda].to(accelerator.device) + # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device") + + return __grad_hook + + grad_hook = create_double_swap_grad_hook(block_idx_cpu, block_idx_cuda) + if single_blocks_to_swap: + if param_name.startswith("single_blocks"): + block_idx = int(param_name.split(".")[1]) + if ( + block_idx not in handled_single_block_indices + and block_idx >= (num_single_blocks - single_blocks_to_swap) - 1 + and block_idx < num_single_blocks - 1 + ): + handled_single_block_indices.add(block_idx) + block_idx_cpu = block_idx + 1 + block_idx_cuda = single_blocks_to_swap - (num_single_blocks - block_idx_cpu) + # print(param_name, block_idx_cpu, block_idx_cuda) + + # create swap hook + def create_single_swap_grad_hook(bidx, bidx_cuda): + def __grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + # swap blocks if necessary + flux.single_blocks[bidx].to("cpu") + flux.single_blocks[bidx_cuda].to(accelerator.device) + # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device") + + return __grad_hook + + grad_hook = create_single_swap_grad_hook(block_idx_cpu, block_idx_cuda) + + if grad_hook is None: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + grad_hook = __grad_hook + + parameter.register_post_accumulate_grad_hook(grad_hook) elif args.blockwise_fused_optimizers: # prepare for additional optimizers and lr schedulers diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py index bdfc32ce..b5afa236 100644 --- a/library/adafactor_fused.py +++ b/library/adafactor_fused.py @@ -2,6 +2,32 @@ import math import torch from transformers import Adafactor +# stochastic rounding for bfloat16 +# The implementation was provided by 2kpr. Thank you very much! + +def copy_stochastic_(target: torch.Tensor, source: torch.Tensor): + """ + copies source into target using stochastic rounding + + Args: + target: the target tensor with dtype=bfloat16 + source: the target tensor with dtype=float32 + """ + # create a random 16 bit integer + result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16)) + + # add the random number to the lower 16 bit of the mantissa + result.add_(source.view(dtype=torch.int32)) + + # mask off the lower 16 bit of the mantissa + result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 + + # copy the higher 16 bit into the target tensor + target.copy_(result.view(dtype=torch.float32)) + + del result + + @torch.no_grad() def adafactor_step_param(self, p, group): if p.grad is None: @@ -48,7 +74,7 @@ def adafactor_step_param(self, p, group): lr = Adafactor._get_lr(group, state) beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) - update = (grad ** 2) + group["eps"][0] + update = (grad**2) + group["eps"][0] if factored: exp_avg_sq_row = state["exp_avg_sq_row"] exp_avg_sq_col = state["exp_avg_sq_col"] @@ -78,7 +104,12 @@ def adafactor_step_param(self, p, group): p_data_fp32.add_(-update) - if p.dtype in {torch.float16, torch.bfloat16}: + # if p.dtype in {torch.float16, torch.bfloat16}: + # p.copy_(p_data_fp32) + + if p.dtype == torch.bfloat16: + copy_stochastic_(p, p_data_fp32) + elif p.dtype == torch.float16: p.copy_(p_data_fp32) @@ -101,6 +132,7 @@ def adafactor_step(self, closure=None): return loss + def patch_adafactor_fused(optimizer: Adafactor): optimizer.step_param = adafactor_step_param.__get__(optimizer) optimizer.step = adafactor_step.__get__(optimizer) diff --git a/library/flux_models.py b/library/flux_models.py index e38119cd..c98d52ec 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1078,6 +1078,7 @@ class Flux(nn.Module): if moving: self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) # print(f"Moved single block {to_cpu_block_index} to cpu.") + to_cpu_block_index += 1 img = img[:, txt.shape[1] :, ...] From 98c91a762513bbce9ebce137da720a448a3da6c9 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Thu, 22 Aug 2024 12:37:41 +0900 Subject: [PATCH 134/356] Fix bug in FLUX multi GPU training --- README.md | 6 +++ flux_train.py | 29 ++++++------- flux_train_network.py | 10 +++-- library/flux_models.py | 6 ++- library/flux_utils.py | 40 ++++++++++++++---- library/strategy_flux.py | 4 +- library/train_util.py | 10 ++--- library/utils.py | 89 ++++++++++++++++++++++++++++++++++++++++ 8 files changed, 156 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 45349ba3..5125c663 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,12 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 22, 2024: +Fixed a bug in multi-GPU training. It should work with fine-tuning and LoRA training. `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. + +`--disable_mmap_load_safetensors` option now works in `flux_train.py`. It speeds up model loading during training in WSL2. It is also effective in reducing memory usage when loading models during multi-GPU training. Please always check if the model is loaded correctly, as it uses a custom implementation of safetensors loading. + + Aug 21, 2024 (update 3): - There is a bug that `--full_bf16` option is enabled even if it is not specified in `flux_train.py`. The bug will be fixed sooner. __Please specify the `--full_bf16` option explicitly, especially when training with 24GB VRAM.__ - Stochastic rounding is now implemented when `--fused_backward_pass` is specified. The implementation is diff --git a/flux_train.py b/flux_train.py index bcf4b956..e7d45e04 100644 --- a/flux_train.py +++ b/flux_train.py @@ -174,7 +174,7 @@ def train(args): # load VAE for caching latents ae = None if cache_latents: - ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") + ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) ae.to(accelerator.device, dtype=weight_dtype) ae.requires_grad_(False) ae.eval() @@ -199,8 +199,8 @@ def train(args): strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy) # load clip_l, t5xxl for caching text encoder outputs - clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu") - t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu") + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors) clip_l.eval() t5xxl.eval() clip_l.requires_grad_(False) @@ -228,7 +228,6 @@ def train(args): if args.sample_prompts is not None: logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") - tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() prompts = load_prompts(args.sample_prompts) @@ -238,9 +237,9 @@ def train(args): for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: if p not in sample_prompts_te_outputs: logger.info(f"cache Text Encoder outputs for prompt: {p}") - tokens_and_masks = tokenize_strategy.tokenize(p) + tokens_and_masks = flux_tokenize_strategy.tokenize(p) sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask ) accelerator.wait_for_everyone() @@ -251,7 +250,9 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") + flux = flux_utils.load_flow_model( + name, args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + ) if args.gradient_checkpointing: flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing) @@ -419,7 +420,7 @@ def train(args): # if we doesn't swap blocks, we can move the model to device flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks]) if is_swapping_blocks: - flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする @@ -439,8 +440,8 @@ def train(args): double_blocks_to_swap = args.double_blocks_to_swap single_blocks_to_swap = args.single_blocks_to_swap - num_double_blocks = len(flux.double_blocks) - num_single_blocks = len(flux.single_blocks) + num_double_blocks = 19 # len(flux.double_blocks) + num_single_blocks = 38 # len(flux.single_blocks) handled_double_block_indices = set() handled_single_block_indices = set() @@ -537,8 +538,8 @@ def train(args): double_blocks_to_swap = args.double_blocks_to_swap single_blocks_to_swap = args.single_blocks_to_swap - num_double_blocks = len(flux.double_blocks) - num_single_blocks = len(flux.single_blocks) + num_double_blocks = 19 # len(flux.double_blocks) + num_single_blocks = 38 # len(flux.single_blocks) for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: @@ -618,7 +619,7 @@ def train(args): ) if is_swapping_blocks: - flux.prepare_block_swap_before_forward() + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() # For --sample_at_first flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) @@ -660,7 +661,7 @@ def train(args): with torch.no_grad(): input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] text_encoder_conds = text_encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask ) if args.full_fp16: text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] diff --git a/flux_train_network.py b/flux_train_network.py index 49bd270c..3e2057e9 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -57,19 +57,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): name = self.get_flux_model_name(args) # if we load to cpu, flux.to(fp8) takes a long time - model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") + model = flux_utils.load_flow_model( + name, args.pretrained_model_name_or_path, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + ) if args.split_mode: model = self.prepare_split_model(model, weight_dtype, accelerator) - clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu") + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) clip_l.eval() # loading t5xxl to cpu takes a long time, so we should load to gpu in future - t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu") + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) t5xxl.eval() - ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") + ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model diff --git a/library/flux_models.py b/library/flux_models.py index c98d52ec..c045aef6 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -745,7 +745,9 @@ class DoubleStreamBlock(nn.Module): return custom_forward - return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask) + return torch.utils.checkpoint.checkpoint( + create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False + ) else: return self._forward(img, txt, vec, pe, txt_attention_mask) @@ -836,7 +838,7 @@ class SingleStreamBlock(nn.Module): return custom_forward - return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe) + return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False) else: return self._forward(x, vec, pe) diff --git a/library/flux_utils.py b/library/flux_utils.py index 166cd833..37166933 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -9,7 +9,7 @@ from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config from library import flux_models -from library.utils import setup_logging +from library.utils import setup_logging, MemoryEfficientSafeOpen setup_logging() import logging @@ -19,32 +19,54 @@ logger = logging.getLogger(__name__) MODEL_VERSION_FLUX_V1 = "flux1" -def load_flow_model(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.Flux: +# temporary copy from sd3_utils TODO refactor +def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: torch.dtype = torch.float32): + if disable_mmap: + # return safetensors.torch.load(open(path, "rb").read()) + # use experimental loader + logger.info(f"Loading without mmap (experimental)") + state_dict = {} + with MemoryEfficientSafeOpen(path) as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key).to(device, dtype=dtype) + return state_dict + else: + try: + return load_file(path, device=device) + except: + return load_file(path) # prevent device invalid Error + + +def load_flow_model( + name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False +) -> flux_models.Flux: logger.info(f"Building Flux model {name}") with torch.device("meta"): model = flux_models.Flux(flux_models.configs[name].params).to(dtype) # load_sft doesn't support torch.device logger.info(f"Loading state dict from {ckpt_path}") - sd = load_file(ckpt_path, device=str(device)) + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = model.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Flux: {info}") return model -def load_ae(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.AutoEncoder: +def load_ae( + name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False +) -> flux_models.AutoEncoder: logger.info("Building AutoEncoder") with torch.device("meta"): ae = flux_models.AutoEncoder(flux_models.configs[name].ae_params).to(dtype) logger.info(f"Loading state dict from {ckpt_path}") - sd = load_file(ckpt_path, device=str(device)) + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = ae.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded AE: {info}") return ae -def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> CLIPTextModel: +def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> CLIPTextModel: logger.info("Building CLIP") CLIPL_CONFIG = { "_name_or_path": "clip-vit-large-patch14/", @@ -139,13 +161,13 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev clip = CLIPTextModel._from_config(config) logger.info(f"Loading state dict from {ckpt_path}") - sd = load_file(ckpt_path, device=str(device)) + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = clip.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded CLIP: {info}") return clip -def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> T5EncoderModel: +def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> T5EncoderModel: T5_CONFIG_JSON = """ { "architectures": [ @@ -185,7 +207,7 @@ def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.devi t5xxl = T5EncoderModel._from_config(config) logger.info(f"Loading state dict from {ckpt_path}") - sd = load_file(ckpt_path, device=str(device)) + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = t5xxl.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded T5xxl: {info}") return t5xxl diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 737af390..b3643cbf 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -137,7 +137,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): t5_attn_mask = data["t5_attn_mask"] if self.apply_t5_attn_mask: - t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) + t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) # FIXME do not mask here!!! return [l_pooled, t5_out, txt_ids, t5_attn_mask] @@ -149,7 +149,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): - # attn_mask is not applied when caching to disk: it is applied when loading from disk + # attn_mask is not applied when caching to disk: it is applied when loading from disk FIXME apply mask when loading l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens( tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk ) diff --git a/library/train_util.py b/library/train_util.py index 8929c192..989758ad 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1104,10 +1104,6 @@ class BaseDataset(torch.utils.data.Dataset): caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy() batch_size = caching_strategy.batch_size or self.batch_size - # if cache to disk, don't cache TE outputs in non-main process - if caching_strategy.cache_to_disk and not is_main_process: - return - logger.info("caching Text Encoder outputs with caching strategy.") image_infos = list(self.image_data.values()) @@ -1120,9 +1116,9 @@ class BaseDataset(torch.utils.data.Dataset): # check disk cache exists and size of latents if caching_strategy.cache_to_disk: - info.text_encoder_outputs_npz = te_out_npz + info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability/main process cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) - if cache_available: # do not add to batch + if cache_available or not is_main_process: # do not add to batch continue batch.append(info) @@ -2638,7 +2634,7 @@ def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset: return train_dataset_group -def load_image(image_path, alpha=False): +def load_image(image_path, alpha=False): try: with Image.open(image_path) as image: if alpha: diff --git a/library/utils.py b/library/utils.py index 7de22d5a..a1620997 100644 --- a/library/utils.py +++ b/library/utils.py @@ -153,6 +153,95 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: v.contiguous().view(torch.uint8).numpy().tofile(f) +class MemoryEfficientSafeOpen: + # does not support metadata loading + def __init__(self, filename): + self.filename = filename + self.header, self.header_size = self._read_header() + self.file = open(filename, "rb") + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def keys(self): + return [k for k in self.header.keys() if k != "__metadata__"] + + def get_tensor(self, key): + if key not in self.header: + raise KeyError(f"Tensor '{key}' not found in the file") + + metadata = self.header[key] + offset_start, offset_end = metadata["data_offsets"] + + if offset_start == offset_end: + tensor_bytes = None + else: + # adjust offset by header size + self.file.seek(self.header_size + 8 + offset_start) + tensor_bytes = self.file.read(offset_end - offset_start) + + return self._deserialize_tensor(tensor_bytes, metadata) + + def _read_header(self): + with open(self.filename, "rb") as f: + header_size = struct.unpack(" Date: Thu, 22 Aug 2024 19:55:31 +0900 Subject: [PATCH 135/356] Fix --debug_dataset to work. --- flux_train.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/flux_train.py b/flux_train.py index e7d45e04..410728d4 100644 --- a/flux_train.py +++ b/flux_train.py @@ -142,6 +142,12 @@ def train(args): args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False ) ) + name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" + t5xxl_max_token_length = ( + args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if name == "schnell" else 512) + ) + strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) + train_dataset_group.set_current_strategies() train_util.debug_dataset(train_dataset_group, True) return From 2d8fa3387a4adfdc2e36f2582e4ffc21864569f0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 22 Aug 2024 19:56:27 +0900 Subject: [PATCH 136/356] Fix to remove zero pad for t5xxl output --- README.md | 5 +++++ library/strategy_flux.py | 23 +++++++++++------------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 5125c663..33b3a9a9 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,11 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 22, 2024 (update 2): +Fixed a bug that the embedding was zero-padded when `--apply_t5_attn_mask` option was applied. Also, the cache file for text encoder outputs now records whether the mask is applied or not. Please note that the cache file will be recreated when switching the `--apply_t5_attn_mask` option. + +Added a script to extract LoRA from the difference between the two models of FLUX.1. Use `networks/flux_extract_lora.py`. See `--help` for details. Normally, more than 50GB of memory is required, but specifying the `--mem_eff_safe_open` option significantly reduces memory usage. However, this option is a custom implementation, so unexpected problems may occur. Please always check if the model is loaded correctly. + Aug 22, 2024: Fixed a bug in multi-GPU training. It should work with fine-tuning and LoRA training. `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. diff --git a/library/strategy_flux.py b/library/strategy_flux.py index b3643cbf..d52b3b8d 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -22,7 +22,7 @@ T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl" class FluxTokenizeStrategy(TokenizeStrategy): - def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None: + def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None: self.t5xxl_max_length = t5xxl_max_length self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) @@ -120,25 +120,24 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): return False if "t5_attn_mask" not in npz: return False + if "apply_t5_attn_mask" not in npz: + return False + npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] + if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e return True - def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray: - return t5_out * np.expand_dims(t5_attn_mask, -1) - def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: data = np.load(npz_path) l_pooled = data["l_pooled"] t5_out = data["t5_out"] txt_ids = data["txt_ids"] t5_attn_mask = data["t5_attn_mask"] - - if self.apply_t5_attn_mask: - t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) # FIXME do not mask here!!! - + # apply_t5_attn_mask should be same as self.apply_t5_attn_mask return [l_pooled, t5_out, txt_ids, t5_attn_mask] def cache_batch_outputs( @@ -149,10 +148,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): - # attn_mask is not applied when caching to disk: it is applied when loading from disk FIXME apply mask when loading - l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens_and_masks, not self.cache_to_disk - ) + # attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True + l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks) if l_pooled.dtype == torch.bfloat16: l_pooled = l_pooled.float() @@ -171,6 +168,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): t5_out_i = t5_out[i] txt_ids_i = txt_ids[i] t5_attn_mask_i = t5_attn_mask[i] + apply_t5_attn_mask_i = self.apply_t5_attn_mask if self.cache_to_disk: np.savez( @@ -179,6 +177,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): t5_out=t5_out_i, txt_ids=txt_ids_i, t5_attn_mask=t5_attn_mask_i, + apply_t5_attn_mask=apply_t5_attn_mask_i, ) else: info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i) From b0a980844a2e02b1b1ae4cf615ae489dbf8ece67 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 22 Aug 2024 19:57:29 +0900 Subject: [PATCH 137/356] added a script to extract LoRA --- networks/flux_extract_lora.py | 219 ++++++++++++++++++++++++++++++++++ 1 file changed, 219 insertions(+) create mode 100644 networks/flux_extract_lora.py diff --git a/networks/flux_extract_lora.py b/networks/flux_extract_lora.py new file mode 100644 index 00000000..3ee6e816 --- /dev/null +++ b/networks/flux_extract_lora.py @@ -0,0 +1,219 @@ +# extract approximating LoRA by svd from two FLUX models +# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo! + +import argparse +import json +import os +import time +import torch +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from tqdm import tqdm +from library import flux_utils, sai_model_spec, model_util, sdxl_model_util +import lora +from library.utils import MemoryEfficientSafeOpen +from library.utils import setup_logging +from networks import lora_flux + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +# CLAMP_QUANTILE = 0.99 +# MIN_DIFF = 1e-1 + + +def save_to_file(file_name, state_dict, metadata, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + save_file(state_dict, file_name, metadata=metadata) + + +def svd( + model_org=None, + model_tuned=None, + save_to=None, + dim=4, + device=None, + save_precision=None, + clamp_quantile=0.99, + min_diff=0.01, + no_metadata=False, + mem_eff_safe_open=False, +): + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + calc_dtype = torch.float + save_dtype = str_to_dtype(save_precision) + store_device = "cpu" + + # open models + lora_weights = {} + if not mem_eff_safe_open: + # use original safetensors.safe_open + open_fn = lambda fn: safe_open(fn, framework="pt") + else: + logger.info("Using memory efficient safe_open") + open_fn = lambda fn: MemoryEfficientSafeOpen(fn) + + with open_fn(model_org) as fo: + # filter keys + keys = [] + for key in fo.keys(): + if not ("single_block" in key or "double_block" in key): + continue + if ".bias" in key: + continue + if "norm" in key: + continue + keys.append(key) + + with open_fn(model_tuned) as ft: + for key in tqdm(keys): + # get tensors and calculate difference + value_o = fo.get_tensor(key) + value_t = ft.get_tensor(key) + mat = value_t.to(calc_dtype) - value_o.to(calc_dtype) + del value_o, value_t + + # extract LoRA weights + if device: + mat = mat.to(device) + out_dim, in_dim = mat.size()[0:2] + rank = min(dim, in_dim, out_dim) # LoRA rank cannot exceed the original dim + + mat = mat.squeeze() + + U, S, Vh = torch.linalg.svd(mat) + + U = U[:, :rank] + S = S[:rank] + U = U @ torch.diag(S) + + Vh = Vh[:rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, clamp_quantile) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + + U = U.to(store_device, dtype=save_dtype).contiguous() + Vh = Vh.to(store_device, dtype=save_dtype).contiguous() + + print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}") + lora_weights[key] = (U, Vh) + del mat, U, S, Vh + + # make state dict for LoRA + lora_sd = {} + for key, (up_weight, down_weight) in lora_weights.items(): + lora_name = key.replace(".weight", "").replace(".", "_") + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + lora_name + lora_sd[lora_name + ".lora_up.weight"] = up_weight + lora_sd[lora_name + ".lora_down.weight"] = down_weight + lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0]) # same as rank + + # minimum metadata + net_kwargs = {} + metadata = { + "ss_v2": str(False), + "ss_base_model_version": flux_utils.MODEL_VERSION_FLUX_V1, + "ss_network_module": "networks.lora_flux", + "ss_network_dim": str(dim), + "ss_network_alpha": str(float(dim)), + "ss_network_args": json.dumps(net_kwargs), + } + + if not no_metadata: + title = os.path.splitext(os.path.basename(save_to))[0] + sai_metadata = sai_model_spec.build_metadata(lora_sd, False, False, False, True, False, time.time(), title, flux="dev") + metadata.update(sai_metadata) + + save_to_file(save_to, lora_sd, metadata, save_dtype) + + logger.info(f"LoRA weights saved to {save_to}") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat", + ) + parser.add_argument( + "--model_org", + type=str, + default=None, + required=True, + help="Original model: safetensors file / 元モデル、safetensors", + ) + parser.add_argument( + "--model_tuned", + type=str, + default=None, + required=True, + help="Tuned model, LoRA is difference of `original to tuned`: safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors", + ) + parser.add_argument( + "--mem_eff_safe_open", + action="store_true", + help="use memory efficient safe_open. This is an experimental feature, use only when memory is not enough." + " / メモリ効率の良いsafe_openを使用する。実装は実験的なものなので、メモリが足りない場合のみ使用してください。", + ) + parser.add_argument( + "--save_to", + type=str, + default=None, + required=True, + help="destination file name: safetensors file / 保存先のファイル名、safetensors", + ) + parser.add_argument( + "--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)" + ) + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) + parser.add_argument( + "--clamp_quantile", + type=float, + default=0.99, + help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99", + ) + # parser.add_argument( + # "--min_diff", + # type=float, + # default=0.01, + # help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /" + # + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01", + # ) + parser.add_argument( + "--no_metadata", + action="store_true", + help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + svd(**vars(args)) From bf9f798985dd75fc2dd1fbc8c8dc775c92176854 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 22 Aug 2024 19:59:38 +0900 Subject: [PATCH 138/356] chore: fix typos, remove debug print --- networks/flux_extract_lora.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/networks/flux_extract_lora.py b/networks/flux_extract_lora.py index 3ee6e816..63ab2960 100644 --- a/networks/flux_extract_lora.py +++ b/networks/flux_extract_lora.py @@ -68,10 +68,10 @@ def svd( logger.info("Using memory efficient safe_open") open_fn = lambda fn: MemoryEfficientSafeOpen(fn) - with open_fn(model_org) as fo: + with open_fn(model_org) as f_org: # filter keys keys = [] - for key in fo.keys(): + for key in f_org.keys(): if not ("single_block" in key or "double_block" in key): continue if ".bias" in key: @@ -80,11 +80,11 @@ def svd( continue keys.append(key) - with open_fn(model_tuned) as ft: + with open_fn(model_tuned) as f_tuned: for key in tqdm(keys): # get tensors and calculate difference - value_o = fo.get_tensor(key) - value_t = ft.get_tensor(key) + value_o = f_org.get_tensor(key) + value_t = f_tuned.get_tensor(key) mat = value_t.to(calc_dtype) - value_o.to(calc_dtype) del value_o, value_t @@ -114,7 +114,7 @@ def svd( U = U.to(store_device, dtype=save_dtype).contiguous() Vh = Vh.to(store_device, dtype=save_dtype).contiguous() - print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}") + # print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}") lora_weights[key] = (U, Vh) del mat, U, S, Vh From afb971f9c36823040eaba3c9e02fdfa0928cd4ee Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 22 Aug 2024 21:33:15 +0900 Subject: [PATCH 139/356] fix SD1.5 LoRA extraction #1490 --- networks/lora.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index 82b8b5b4..6f33f1a1 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -815,7 +815,8 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh weights_sd = torch.load(file, map_location="cpu") # if keys are Diffusers based, convert to SAI based - convert_diffusers_to_sai_if_needed(weights_sd) + if is_sdxl: + convert_diffusers_to_sai_if_needed(weights_sd) # get dim/alpha mapping modules_dim = {} @@ -840,7 +841,13 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh module_class = LoRAInfModule if for_inference else LoRAModule network = LoRANetwork( - text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class + text_encoder, + unet, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + is_sdxl=is_sdxl, ) # block lr From 81411a398eb4ce28d84cc2da8238ff013d40d62f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 22 Aug 2024 22:02:29 +0900 Subject: [PATCH 140/356] speed up getting image sizes --- library/strategy_base.py | 7 ++++++- library/strategy_flux.py | 9 +++------ library/strategy_sd.py | 12 ++++-------- library/strategy_sd3.py | 9 +++------ library/train_util.py | 23 ++++++++++++++++++++++- 5 files changed, 38 insertions(+), 22 deletions(-) diff --git a/library/strategy_base.py b/library/strategy_base.py index e7d3a97e..6a01c30a 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -204,9 +204,14 @@ class LatentsCachingStrategy: def batch_size(self): return self._batch_size - def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + @property + def cache_suffix(self): raise NotImplementedError + def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]: + w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x") + return int(w), int(h) + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: raise NotImplementedError diff --git a/library/strategy_flux.py b/library/strategy_flux.py index d52b3b8d..887113ca 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -189,12 +189,9 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy): def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) - def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: - npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX) - if len(npz_file) == 0: - return None, None - w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") - return int(w), int(h) + @property + def cache_suffix(self) -> str: + return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: return ( diff --git a/library/strategy_sd.py b/library/strategy_sd.py index 83ffaa31..af472e49 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -108,14 +108,10 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): self.suffix = ( SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX ) - - def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: - # does not include old npz - npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + self.suffix) - if len(npz_file) == 0: - return None, None - w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") - return int(w), int(h) + + @property + def cache_suffix(self) -> str: + return self.suffix def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: # support old .npz diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index a2281890..9fde0208 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -222,12 +222,9 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy): def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) - def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: - npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX) - if len(npz_file) == 0: - return None, None - w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") - return int(w), int(h) + @property + def cache_suffix(self) -> str: + return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: return ( diff --git a/library/train_util.py b/library/train_util.py index 989758ad..dcc01f6f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1739,9 +1739,30 @@ class DreamBoothDataset(BaseDataset): strategy = LatentsCachingStrategy.get_strategy() if strategy is not None: logger.info("get image size from name of cache files") + + # make image path to npz path mapping + npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix)) + npz_paths.sort() + npz_path_index = 0 + size_set_count = 0 for i, img_path in enumerate(tqdm(img_paths)): - w, h = strategy.get_image_size_from_disk_cache_path(img_path) + l = len(os.path.splitext(img_path)[0]) # remove extension + found = False + while npz_path_index < len(npz_paths): # until found or end of npz_paths + # npz_paths are sorted, so if npz_path > img_path, img_path is not found + if npz_paths[npz_path_index][:l] > img_path[:l]: + break + if npz_paths[npz_path_index][:l] == img_path[:l]: # found + found = True + break + npz_path_index += 1 # next npz_path + + if found: + w, h = strategy.get_image_size_from_disk_cache_path(img_path, npz_paths[npz_path_index]) + else: + w, h = None, None + if w is not None and h is not None: sizes[i] = [w, h] size_set_count += 1 From 1e8108fec9962333e4cf2a8db1dcedf657049900 Mon Sep 17 00:00:00 2001 From: liesen Date: Sat, 24 Aug 2024 01:38:17 +0300 Subject: [PATCH 141/356] Handle args.v_parameterization properly for MinSNR and changed prediction target --- sdxl_train.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index 46d7860b..14b25965 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -590,7 +590,11 @@ def train(args): with accelerator.autocast(): noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) - target = noise + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise if ( args.min_snr_gamma @@ -606,7 +610,7 @@ def train(args): loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: From 2e89cd2cc634c27add7a04c21fcb6d0e16716a2b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Aug 2024 12:39:54 +0900 Subject: [PATCH 142/356] Fix issue with attention mask not being applied in single blocks --- README.md | 3 ++ flux_train_network.py | 4 +-- library/flux_models.py | 62 +++++++++++++++++++++--------------------- 3 files changed, 36 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 33b3a9a9..4151bf44 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,9 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 24, 2024: +Fixed an issue where the attention mask was not applied in single blocks when `--apply_t5_attn_mask` was specified. + Aug 22, 2024 (update 2): Fixed a bug that the embedding was zero-padded when `--apply_t5_attn_mask` option was applied. Also, the cache file for text encoder outputs now records whether the mask is applied or not. Please note that the cache file will be recreated when switching the `--apply_t5_attn_mask` option. diff --git a/flux_train_network.py b/flux_train_network.py index 3e2057e9..82f77a77 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -243,7 +243,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): self.flux_upper.to("cpu") clean_memory_on_device(self.target_device) self.flux_lower.to(self.target_device) - return self.flux_lower(img, txt, vec, pe) + return self.flux_lower(img, txt, vec, pe, txt_attention_mask) wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) clean_memory_on_device(accelerator.device) @@ -352,7 +352,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): intermediate_txt.requires_grad_(True) vec.requires_grad_(True) pe.requires_grad_(True) - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe) + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) diff --git a/library/flux_models.py b/library/flux_models.py index c045aef6..b5726c29 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -752,18 +752,6 @@ class DoubleStreamBlock(nn.Module): else: return self._forward(img, txt, vec, pe, txt_attention_mask) - # def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): - # if self.training and self.gradient_checkpointing: - # def create_custom_forward(func): - # def custom_forward(*inputs): - # return func(*inputs) - # return custom_forward - # return torch.utils.checkpoint.checkpoint( - # create_custom_forward(self._forward), img, txt, vec, pe, use_reentrant=USE_REENTRANT - # ) - # else: - # return self._forward(img, txt, vec, pe) - class SingleStreamBlock(nn.Module): """ @@ -809,7 +797,7 @@ class SingleStreamBlock(nn.Module): self.gradient_checkpointing = False self.cpu_offload_checkpointing = False - def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: mod, _ = self.modulation(vec) x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) @@ -817,16 +805,35 @@ class SingleStreamBlock(nn.Module): q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) q, k = self.norm(q, k, v) + # make attention mask if not None + attn_mask = None + if txt_attention_mask is not None: + # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask + attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len + attn_mask = torch.cat( + ( + attn_mask, + torch.ones( + attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool + ), + ), + dim=1, + ) # b, seq_len + img_len = x_len + + # broadcast attn_mask to all heads + attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1) + # compute attention - attn = attention(q, k, v, pe=pe) + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) + # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) return x + mod.gate * output - def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: if self.training and self.gradient_checkpointing: if not self.cpu_offload_checkpointing: - return checkpoint(self._forward, x, vec, pe, use_reentrant=False) + return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False) # cpu offload checkpointing @@ -838,19 +845,11 @@ class SingleStreamBlock(nn.Module): return custom_forward - return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False) + return torch.utils.checkpoint.checkpoint( + create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False + ) else: - return self._forward(x, vec, pe) - - # def forward(self, x: Tensor, vec: Tensor, pe: Tensor): - # if self.training and self.gradient_checkpointing: - # def create_custom_forward(func): - # def custom_forward(*inputs): - # return func(*inputs) - # return custom_forward - # return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=USE_REENTRANT) - # else: - # return self._forward(x, vec, pe) + return self._forward(x, vec, pe, txt_attention_mask) class LastLayer(nn.Module): @@ -1053,7 +1052,7 @@ class Flux(nn.Module): if not self.single_blocks_to_swap: for block in self.single_blocks: - img = block(img, vec=vec, pe=pe) + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning for block_idx in range(self.single_blocks_to_swap): @@ -1075,7 +1074,7 @@ class Flux(nn.Module): block.to(self.device) # move to cuda # print(f"Moved single block {block_idx} to cuda.") - img = block(img, vec=vec, pe=pe) + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) if moving: self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) @@ -1250,10 +1249,11 @@ class FluxLower(nn.Module): txt: Tensor, vec: Tensor | None = None, pe: Tensor | None = None, + txt_attention_mask: Tensor | None = None, ) -> Tensor: img = torch.cat((txt, img), 1) for block in self.single_blocks: - img = block(img, vec=vec, pe=pe) + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) img = img[:, txt.shape[1] :, ...] img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) From cf689e7aa697877a0eee58622035ab702ce59d3e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Aug 2024 16:35:43 +0900 Subject: [PATCH 143/356] feat: Add option to split projection layers and apply LoRA --- README.md | 14 ++ networks/check_lora_weights.py | 2 +- networks/convert_flux_lora.py | 51 ++++-- networks/lora_flux.py | 322 +++++++++++++++++++++++++++------ 4 files changed, 323 insertions(+), 66 deletions(-) diff --git a/README.md b/README.md index 4151bf44..7d326a86 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,20 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 24, 2024 (update 2): + +__Experimental__ Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them in FLUX.1 LoRA training. Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available). + +The number of parameters may increase slightly, so the expressiveness may increase, but the training time may be longer. No detailed verification has been done. + +This implementation is experimental, so it may be deprecated or changed in the future. + +The .safetensors file of the trained model is compatible with the normal LoRA model of sd-scripts, so it should be usable in inference environments such as ComfyUI as it is. Also, converting it to AI-toolkit (Diffusers) format with `convert_flux_lora.py` will reduce the size. It should be no problem to convert it if you use it in the inference environment. + +Technical details: In the implementation of Black Forest Labs' model, the projection layers of q/k/v (and txt in single blocks) are concatenated into one. If LoRA is added there as it is, the LoRA module is only one, and the dimension is large. In contrast, in the implementation of Diffusers, the projection layers of q/k/v/txt are separated. Therefore, the LoRA module is applied to q/k/v/txt separately, and the dimension is smaller. This option is for training LoRA similar to the latter. + +The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large. + Aug 24, 2024: Fixed an issue where the attention mask was not applied in single blocks when `--apply_t5_attn_mask` was specified. diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index 794659c9..b5b5e61a 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -18,7 +18,7 @@ def main(file): keys = list(sd.keys()) for key in keys: - if "lora_up" in key or "lora_down" in key: + if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key: values.append((key, sd[key])) print(f"number of LoRA modules: {len(values)}") diff --git a/networks/convert_flux_lora.py b/networks/convert_flux_lora.py index e9743534..bd4c1cf7 100644 --- a/networks/convert_flux_lora.py +++ b/networks/convert_flux_lora.py @@ -266,11 +266,12 @@ def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): if sds_key + ".lora_down.weight" not in sds_sd: return down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + up_weight = sds_sd.pop(sds_key + ".lora_up.weight") + sd_lora_rank = down_weight.shape[0] # scale weight by alpha and dim - rank = down_weight.shape[0] alpha = sds_sd.pop(sds_key + ".alpha") - scale = alpha / rank + scale = alpha / sd_lora_rank # calculate scale_down and scale_up scale_down = scale @@ -279,23 +280,49 @@ def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): scale_down *= 2 scale_up /= 2 - ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] - ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] - - num_splits = len(ait_keys) - up_weight = sds_sd.pop(sds_key + ".lora_up.weight") - - # down_weight is copied to each split - ait_sd.update({k: down_weight * scale_down for k in ait_down_keys}) + down_weight = down_weight * scale_down + up_weight = up_weight * scale_up # calculate dims if not provided + num_splits = len(ait_keys) if dims is None: dims = [up_weight.shape[0] // num_splits] * num_splits else: assert sum(dims) == up_weight.shape[0] - # up_weight is split to each split - ait_sd.update({k: v * scale_up for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) + # check upweight is sparse or not + is_sparse = False + if sd_lora_rank % num_splits == 0: + ait_rank = sd_lora_rank // num_splits + is_sparse = True + i = 0 + for j in range(len(dims)): + for k in range(len(dims)): + if j == k: + continue + is_sparse = is_sparse and torch.all(up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0) + i += dims[j] + if is_sparse: + logger.info(f"weight is sparse: {sds_key}") + + # make ai-toolkit weight + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + if not is_sparse: + # down_weight is copied to each split + ait_sd.update({k: down_weight for k in ait_down_keys}) + + # up_weight is split to each split + ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) + else: + # down_weight is chunked to each split + ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) + + # up_weight is sparse: only non-zero values are copied to each split + i = 0 + for j in range(len(dims)): + ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous() + i += dims[j] def convert_sd_scripts_to_ai_toolkit(sds_sd): diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 4da33542..efc7847e 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -39,6 +39,7 @@ class LoRAModule(torch.nn.Module): dropout=None, rank_dropout=None, module_dropout=None, + split_dims: Optional[List[int]] = None, ): """if alpha == 0 or None, alpha is rank (no scaling).""" super().__init__() @@ -52,16 +53,34 @@ class LoRAModule(torch.nn.Module): out_dim = org_module.out_features self.lora_dim = lora_dim + self.split_dims = split_dims - if org_module.__class__.__name__ == "Conv2d": - kernel_size = org_module.kernel_size - stride = org_module.stride - padding = org_module.padding - self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) - self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + if split_dims is None: + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) else: - self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) - self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + # conv2d not supported + assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim" + assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear" + # print(f"split_dims: {split_dims}") + self.lora_down = torch.nn.ModuleList( + [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] + ) + self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) + for lora_down in self.lora_down: + torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) + for lora_up in self.lora_up: + torch.nn.init.zeros_(lora_up.weight) if type(alpha) == torch.Tensor: alpha = alpha.detach().float().numpy() # without casting, bf16 causes error @@ -70,9 +89,6 @@ class LoRAModule(torch.nn.Module): self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える # same as microsoft's - torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) - torch.nn.init.zeros_(self.lora_up.weight) - self.multiplier = multiplier self.org_module = org_module # remove in applying self.dropout = dropout @@ -92,30 +108,56 @@ class LoRAModule(torch.nn.Module): if torch.rand(1) < self.module_dropout: return org_forwarded - lx = self.lora_down(x) + if self.split_dims is None: + lx = self.lora_down(x) - # normal dropout - if self.dropout is not None and self.training: - lx = torch.nn.functional.dropout(lx, p=self.dropout) + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) - # rank dropout - if self.rank_dropout is not None and self.training: - mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout - if len(lx.size()) == 3: - mask = mask.unsqueeze(1) # for Text Encoder - elif len(lx.size()) == 4: - mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d - lx = lx * mask + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask - # scaling for rank dropout: treat as if the rank is changed - # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる - scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded + lx * self.multiplier * scale else: - scale = self.scale + lxs = [lora_down(x) for lora_down in self.lora_down] - lx = self.lora_up(lx) + # normal dropout + if self.dropout is not None and self.training: + lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs] - return org_forwarded + lx * self.multiplier * scale + # rank dropout + if self.rank_dropout is not None and self.training: + masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs] + for i in range(len(lxs)): + if len(lx.size()) == 3: + masks[i] = masks[i].unsqueeze(1) + elif len(lx.size()) == 4: + masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1) + lxs[i] = lxs[i] * masks[i] + + # scaling for rank dropout: treat as if the rank is changed + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + + return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale class LoRAInfModule(LoRAModule): @@ -152,31 +194,50 @@ class LoRAInfModule(LoRAModule): if device is None: device = org_device - # get up/down weight - up_weight = sd["lora_up.weight"].to(torch.float).to(device) - down_weight = sd["lora_down.weight"].to(torch.float).to(device) + if self.split_dims is None: + # get up/down weight + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + up_weight = sd["lora_up.weight"].to(torch.float).to(device) - # merge weight - if len(weight.size()) == 2: - # linear - weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - weight - + self.multiplier - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * self.scale - ) + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + self.multiplier * conved * self.scale + # split_dims + total_dims = sum(self.split_dims) + for i in range(len(self.split_dims)): + # get up/down weight + down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim) + up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank) - # set weight to org_module - org_sd["weight"] = weight.to(dtype) - self.org_module.load_state_dict(org_sd) + # pad up_weight -> (total_dims, rank) + padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float) + padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight + + # merge weight + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) # 復元できるマージのため、このモジュールのweightを返す def get_weight(self, multiplier=None): @@ -211,7 +272,14 @@ class LoRAInfModule(LoRAModule): def default_forward(self, x): # logger.info(f"default_forward {self.lora_name} {x.size()}") - return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + if self.split_dims is None: + lx = self.lora_down(x) + lx = self.lora_up(lx) + return self.org_forward(x) + lx * self.multiplier * self.scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale def forward(self, x): if not self.enabled: @@ -257,6 +325,11 @@ def create_network( if train_blocks is not None: assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}" + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoders, @@ -270,6 +343,7 @@ def create_network( conv_lora_dim=conv_dim, conv_alpha=conv_alpha, train_blocks=train_blocks, + split_qkv=split_qkv, varbose=True, ) @@ -311,10 +385,34 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh modules_dim[lora_name] = dim # logger.info(lora_name, value.size(), dim) + # # split qkv + # double_qkv_rank = None + # single_qkv_rank = None + # rank = None + # for lora_name, dim in modules_dim.items(): + # if "double" in lora_name and "qkv" in lora_name: + # double_qkv_rank = dim + # elif "single" in lora_name and "linear1" in lora_name: + # single_qkv_rank = dim + # elif rank is None: + # rank = dim + # if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None: + # break + # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or ( + # single_qkv_rank is not None and single_qkv_rank != rank + # ) + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + module_class = LoRAInfModule if for_inference else LoRAModule network = LoRANetwork( - text_encoders, flux, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class + text_encoders, + flux, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, ) return network, weights_sd @@ -344,6 +442,7 @@ class LoRANetwork(torch.nn.Module): modules_dim: Optional[Dict[str, int]] = None, modules_alpha: Optional[Dict[str, int]] = None, train_blocks: Optional[str] = None, + split_qkv: bool = False, varbose: Optional[bool] = False, ) -> None: super().__init__() @@ -357,6 +456,7 @@ class LoRANetwork(torch.nn.Module): self.rank_dropout = rank_dropout self.module_dropout = module_dropout self.train_blocks = train_blocks if train_blocks is not None else "all" + self.split_qkv = split_qkv self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -373,6 +473,8 @@ class LoRANetwork(torch.nn.Module): logger.info( f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" ) + if self.split_qkv: + logger.info(f"split qkv for LoRA") # create module instances def create_modules( @@ -420,6 +522,14 @@ class LoRANetwork(torch.nn.Module): skipped.append(lora_name) continue + # qkv split + split_dims = None + if is_flux and split_qkv: + if "double" in lora_name and "qkv" in lora_name: + split_dims = [3072] * 3 + elif "single" in lora_name and "linear1" in lora_name: + split_dims = [3072] * 3 + [12288] + lora = module_class( lora_name, child_module, @@ -429,6 +539,7 @@ class LoRANetwork(torch.nn.Module): dropout=dropout, rank_dropout=rank_dropout, module_dropout=module_dropout, + split_dims=split_dims, ) loras.append(lora) return loras, skipped @@ -492,6 +603,111 @@ class LoRANetwork(torch.nn.Module): info = self.load_state_dict(weights_sd, False) return info + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to splitted qkv weight + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # split qkv + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + continue + + weight = state_dict[key] + lora_name = key.split(".")[0] + if "lora_down" in key and "weight" in key: + # dense weight (rank*3, in_dim) + split_weight = torch.chunk(weight, len(split_dims), dim=0) + for i, split_w in enumerate(split_weight): + state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w + + del state_dict[key] + # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}") + elif "lora_up" in key and "weight" in key: + # sparse weight (out_dim=sum(split_dims), rank*3) + rank = weight.size(1) // len(split_dims) + i = 0 + for j in range(len(split_dims)): + state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dims[j], j * rank : (j + 1) * rank] + i += split_dims[j] + del state_dict[key] + + # # check is sparse + # i = 0 + # is_zero = True + # for j in range(len(split_dims)): + # for k in range(len(split_dims)): + # if j == k: + # continue + # is_zero = is_zero and torch.all(weight[i : i + split_dims[j], k * rank : (k + 1) * rank] == 0) + # i += split_dims[j] + # if not is_zero: + # logger.warning(f"weight is not sparse: {key}") + # else: + # logger.info(f"weight is sparse: {key}") + + # print( + # f"split {key}: {weight.shape} to {[state_dict[k].shape for k in [f'{lora_name}.lora_up.{j}.weight' for j in range(len(split_dims))]]}" + # ) + + # alpha is unchanged + + return super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination, prefix, keep_vars) + + # merge qkv + state_dict = super().state_dict(destination, prefix, keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + rank = up_weights[0].size(1) + up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(len(split_dims)): + up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j] + i += split_dims[j] + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") From 5639c2adc0085e2e995bb3eee5a278aace397e7a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Aug 2024 16:37:49 +0900 Subject: [PATCH 144/356] fix typo --- networks/lora_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index efc7847e..07a80f0b 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -604,7 +604,7 @@ class LoRANetwork(torch.nn.Module): return info def load_state_dict(self, state_dict, strict=True): - # override to convert original weight to splitted qkv weight + # override to convert original weight to split qkv if not self.split_qkv: return super().load_state_dict(state_dict, strict) From d5c076cf9007f86f6dd1b9ecdfc5531336774b2f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Aug 2024 21:21:39 +0900 Subject: [PATCH 145/356] update readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 946df58f..81a54937 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress +- `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened! - Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available. From 72287d39c76176c0e1c16e8da4f5ddc6f94ea7d6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 25 Aug 2024 16:01:24 +0900 Subject: [PATCH 146/356] feat: Add `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training --- README.md | 4 ++++ library/flux_train_utils.py | 15 +++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 282f3b3b..562dcdb2 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,10 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 25, 2024: +Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`. +Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` + Aug 24, 2024 (update 2): __Experimental__ Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them in FLUX.1 LoRA training. Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available). diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 1d3f80d7..75f70a54 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -380,9 +380,19 @@ def get_noisy_model_input_and_timesteps( t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) else: t = torch.rand((bsz,), device=device) + timesteps = t * 1000.0 t = t.view(-1, 1, 1, 1) noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "shift": + shift = args.discrete_flow_shift + logits_norm = torch.randn(bsz, device=device) + timesteps = logits_norm.sigmoid() + timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -559,9 +569,10 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--timestep_sampling", - choices=["sigma", "uniform", "sigmoid"], + choices=["sigma", "uniform", "sigmoid", "shift"], default="sigma", - help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。", ) parser.add_argument( "--sigmoid_scale", From 0087a46e14c8e568982cbe3a5d9b9c561b175abf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 27 Aug 2024 19:59:40 +0900 Subject: [PATCH 147/356] FLUX.1 LoRA supports CLIP-L --- README.md | 8 ++++ flux_train_network.py | 40 +++++++++++++----- library/flux_train_utils.py | 8 ++-- library/strategy_flux.py | 3 +- networks/lora_flux.py | 4 +- train_network.py | 81 ++++++++++++++++++++++++------------- 6 files changed, 101 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 562dcdb2..1203b5eb 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,14 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 27, 2024: + +- FLUX.1 LoRA training now supports CLIP-L LoRA. Please remove `--network_train_unet_only`. T5XXL is not trained. The output of T5XXL is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. The trained LoRA can be used with ComfyUI. + - `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. +- `--sigmoid_scale` is now effective even when `--timestep_sampling shift` is specified. Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. + +- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is not required (Flux is fp8, and CLIP-L is bf16/fp16, regardless of the `--fp8_base` option). + Aug 25, 2024: Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`. Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` diff --git a/flux_train_network.py b/flux_train_network.py index 82f77a77..1a40de61 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -40,9 +40,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): train_dataset_group.is_text_encoder_output_cacheable() ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - assert ( - args.network_train_unet_only or not args.cache_text_encoder_outputs - ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + # assert ( + # args.network_train_unet_only or not args.cache_text_encoder_outputs + # ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + if not args.network_train_unet_only: + logger.info( + "network for CLIP-L only will be trained. T5XXL will not be trained / CLIP-Lのネットワークのみが学習されます。T5XXLは学習されません" + ) if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") @@ -137,12 +141,25 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) def get_models_for_text_encoding(self, args, accelerator, text_encoders): - return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])] + if args.cache_text_encoder_outputs: + if self.is_train_text_encoder(args): + return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached + else: + return text_encoders # ignored + else: + return text_encoders # both CLIP-L and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [True, False] if self.is_train_text_encoder(args) else [False, False] def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: return strategy_flux.FluxTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, None, False, apply_t5_attn_mask=args.apply_t5_attn_mask + args.cache_text_encoder_outputs_to_disk, + None, + False, + is_partial=self.is_train_text_encoder(args), + apply_t5_attn_mask=args.apply_t5_attn_mask, ) else: return None @@ -190,9 +207,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): accelerator.wait_for_everyone() # move back to cpu - logger.info("move text encoders back to cpu") - text_encoders[0].to("cpu") # , dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU - text_encoders[1].to("cpu") # , dtype=torch.float32) + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[1].to("cpu") clean_memory_on_device(accelerator.device) if not args.lowram: @@ -297,7 +316,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): if args.gradient_checkpointing: noisy_model_input.requires_grad_(True) for t in text_encoder_conds: - t.requires_grad_(True) + if t.dtype.is_floating_point: + t.requires_grad_(True) img_ids.requires_grad_(True) guidance_vec.requires_grad_(True) @@ -384,7 +404,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift def is_text_encoder_not_needed_for_training(self, args): - return args.cache_text_encoder_outputs + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) def setup_parser() -> argparse.ArgumentParser: diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 75f70a54..a8e94ac0 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -58,7 +58,7 @@ def sample_images( logger.info("") logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") - if not os.path.isfile(args.sample_prompts): + if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None: logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return @@ -66,7 +66,8 @@ def sample_images( # unwrap unet and text_encoder(s) flux = accelerator.unwrap_model(flux) - text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + if text_encoders is not None: + text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) prompts = load_prompts(args.sample_prompts) @@ -134,7 +135,7 @@ def sample_image_inference( accelerator: Accelerator, args: argparse.Namespace, flux: flux_models.Flux, - text_encoders: List[CLIPTextModel], + text_encoders: Optional[List[CLIPTextModel]], ae: flux_models.AutoEncoder, save_dir, prompt_dict, @@ -387,6 +388,7 @@ def get_noisy_model_input_and_timesteps( elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling timesteps = logits_norm.sigmoid() timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) diff --git a/library/strategy_flux.py b/library/strategy_flux.py index d52b3b8d..5d083913 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -60,7 +60,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy): if apply_t5_attn_mask is None: apply_t5_attn_mask = self.apply_t5_attn_mask - clip_l, t5xxl = models + clip_l, t5xxl = models if len(models) == 2 else (models[0], None) l_tokens, t5_tokens = tokens[:2] t5_attn_mask = tokens[2] if len(tokens) > 2 else None @@ -81,6 +81,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy): else: t5_out = None txt_ids = None + t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 07a80f0b..fcb56a46 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -401,7 +401,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or ( # single_qkv_rank is not None and single_qkv_rank != rank # ) - split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined module_class = LoRAInfModule if for_inference else LoRAModule @@ -421,7 +421,7 @@ class LoRANetwork(torch.nn.Module): # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"] FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"] LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" diff --git a/train_network.py b/train_network.py index cab0ec52..048c7e7b 100644 --- a/train_network.py +++ b/train_network.py @@ -127,8 +127,15 @@ class NetworkTrainer: return None def get_models_for_text_encoding(self, args, accelerator, text_encoders): + """ + Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models. + """ return text_encoders + # returns a list of bool values indicating whether each text encoder should be trained + def get_text_encoders_train_flags(self, args, text_encoders): + return [True] * len(text_encoders) if self.is_train_text_encoder(args) else [False] * len(text_encoders) + def is_train_text_encoder(self, args): return not args.network_train_unet_only @@ -136,11 +143,6 @@ class NetworkTrainer: for t_enc in text_encoders: t_enc.to(accelerator.device, dtype=weight_dtype) - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], weight_dtype) - return encoder_hidden_states - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred @@ -313,7 +315,7 @@ class NetworkTrainer: collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) if args.debug_dataset: - train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly + train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: @@ -437,8 +439,10 @@ class NetworkTrainer: if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - for t_enc in text_encoders: - t_enc.gradient_checkpointing_enable() + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): + if flag: + if t_enc.supports_gradient_checkpointing: + t_enc.gradient_checkpointing_enable() del t_enc network.enable_gradient_checkpointing() # may have no effect @@ -522,14 +526,17 @@ class NetworkTrainer: unet_weight_dtype = te_weight_dtype = weight_dtype # Experimental Feature: Put base model into fp8 to save vram - if args.fp8_base: + if args.fp8_base or args.fp8_base_unet: assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" assert ( args.mixed_precision != "no" ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" - accelerator.print("enable fp8 training.") + accelerator.print("enable fp8 training for U-Net.") unet_weight_dtype = torch.float8_e4m3fn - te_weight_dtype = torch.float8_e4m3fn + + if not args.fp8_base_unet: + accelerator.print("enable fp8 training for Text Encoder.") + te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory @@ -546,19 +553,18 @@ class NetworkTrainer: t_enc.to(dtype=te_weight_dtype) if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"): # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to( - dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"): - t_enc.encoder.embeddings.to( - dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: + flags = self.get_text_encoders_train_flags(args, text_encoders) ds_model = deepspeed_utils.prepare_deepspeed_model( args, unet=unet if train_unet else None, - text_encoder1=text_encoders[0] if train_text_encoder else None, - text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None, + text_encoder1=text_encoders[0] if flags[0] else None, + text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None, network=network, ) ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -571,11 +577,14 @@ class NetworkTrainer: else: unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator if train_text_encoder: + text_encoders = [ + (accelerator.prepare(t_enc) if flag else t_enc) + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)) + ] if len(text_encoders) > 1: - text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders] + text_encoder = text_encoders else: - text_encoder = accelerator.prepare(text_encoder) - text_encoders = [text_encoder] + text_encoder = text_encoders[0] else: pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set @@ -587,11 +596,11 @@ class NetworkTrainer: if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() - for t_enc in text_encoders: + for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): t_enc.train() # set top parameter requires_grad = True for gradient checkpointing works - if train_text_encoder: + if frag: t_enc.text_model.embeddings.requires_grad_(True) else: @@ -736,6 +745,7 @@ class NetworkTrainer: "ss_huber_schedule": args.huber_schedule, "ss_huber_c": args.huber_c, "ss_fp8_base": args.fp8_base, + "ss_fp8_base_unet": args.fp8_base_unet, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1004,6 +1014,7 @@ class NetworkTrainer: for t_enc in text_encoders: del t_enc text_encoders = [] + text_encoder = None # For --sample_at_first self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) @@ -1018,7 +1029,7 @@ class NetworkTrainer: # log device and dtype for each model logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") for t_enc in text_encoders: - logger.info(f"text_encoder dtype: {te_weight_dtype}, device: {t_enc.device}") + logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}") clean_memory_on_device(accelerator.device) @@ -1073,12 +1084,17 @@ class NetworkTrainer: text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs - else: + if ( + text_encoder_conds is None + or len(text_encoder_conds) == 0 + or text_encoder_conds[0] is None + or train_text_encoder + ): with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: # SD only - text_encoder_conds = get_weighted_text_embeddings( + encoded_text_encoder_conds = get_weighted_text_embeddings( tokenizers[0], text_encoder, batch["captions"], @@ -1088,13 +1104,18 @@ class NetworkTrainer: ) else: input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] - text_encoder_conds = text_encoding_strategy.encode_tokens( + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens( tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids, ) if args.full_fp16: - text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] + + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( @@ -1257,6 +1278,12 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") + parser.add_argument( + "--fp8_base_unet", + action="store_true", + help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16" + " / U-Net(またはDiT)にfp8を使用する。Text Encoderはfp16またはbf16", + ) parser.add_argument( "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" From 3be712e3e011b0378fad389641cec0c1869555ab Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 27 Aug 2024 21:40:02 +0900 Subject: [PATCH 148/356] feat: Update direct loading fp8 ckpt for LoRA training --- README.md | 7 +++- flux_minimal_inference.py | 27 +----------- flux_train_network.py | 16 +++++++- library/flux_utils.py | 12 ++++-- library/utils.py | 62 +++++++++++++++++++++++++++- networks/flux_merge_lora.py | 82 ++++++++++++++++++++++++++----------- 6 files changed, 151 insertions(+), 55 deletions(-) diff --git a/README.md b/README.md index 1203b5eb..0108ada5 100644 --- a/README.md +++ b/README.md @@ -9,13 +9,18 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 27, 2024 (update 2): +In FLUX.1 LoRA training, when `--fp8_base` is specified, the FLUX.1 model file with fp8 (`float8_e4m3fn` type) can be loaded directly. Also, in `flux_minimal_inference.py`, it is possible to load it by specifying `fp8 (float8_e4m3fn)` in `--flux_dtype`. + +In `flux_merge_lora.py`, you can now specify the precision at save time with `fp8` (see `--help` for details). Also, if you do not specify the merge model, only the model type conversion will be performed. + Aug 27, 2024: - FLUX.1 LoRA training now supports CLIP-L LoRA. Please remove `--network_train_unet_only`. T5XXL is not trained. The output of T5XXL is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. The trained LoRA can be used with ComfyUI. - `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. - `--sigmoid_scale` is now effective even when `--timestep_sampling shift` is specified. Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. -- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is not required (Flux is fp8, and CLIP-L is bf16/fp16, regardless of the `--fp8_base` option). +- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is automatically enabled. Aug 25, 2024: Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`. diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 5b8aa250..56c1b198 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -10,7 +10,6 @@ import einops import numpy as np import torch -from safetensors.torch import safe_open, load_file from tqdm import tqdm from PIL import Image import accelerate @@ -21,7 +20,7 @@ from library.device_utils import init_ipex, get_preferred_device init_ipex() -from library.utils import setup_logging +from library.utils import setup_logging, str_to_dtype setup_logging() import logging @@ -288,28 +287,6 @@ if __name__ == "__main__": name = "schnell" if "schnell" in args.ckpt_path else "dev" # TODO change this to a more robust way is_schnell = name == "schnell" - def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: - if s is None: - return default_dtype - if s in ["bf16", "bfloat16"]: - return torch.bfloat16 - elif s in ["fp16", "float16"]: - return torch.float16 - elif s in ["fp32", "float32"]: - return torch.float32 - elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]: - return torch.float8_e4m3fn - elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]: - return torch.float8_e4m3fnuz - elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]: - return torch.float8_e5m2 - elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]: - return torch.float8_e5m2fnuz - elif s in ["fp8", "float8"]: - return torch.float8_e4m3fn # default fp8 - else: - raise ValueError(f"Unsupported dtype: {s}") - def is_fp8(dt): return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz] @@ -348,7 +325,7 @@ if __name__ == "__main__": encoding_strategy = strategy_flux.FluxTextEncodingStrategy() # DiT - model = flux_utils.load_flow_model(name, args.ckpt_path, flux_dtype, loading_device) + model = flux_utils.load_flow_model(name, args.ckpt_path, None, loading_device) model.eval() logger.info(f"Casting model to {flux_dtype}") model.to(flux_dtype) # make sure model is dtype diff --git a/flux_train_network.py b/flux_train_network.py index 1a40de61..4a63c2de 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -29,6 +29,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: logger.warning( "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" @@ -61,9 +64,20 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): name = self.get_flux_model_name(args) # if we load to cpu, flux.to(fp8) takes a long time + if args.fp8_base: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + model = flux_utils.load_flow_model( - name, args.pretrained_model_name_or_path, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors ) + if args.fp8_base: + # check dtype of model + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") + elif model.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 FLUX model") if args.split_mode: model = self.prepare_split_model(model, weight_dtype, accelerator) diff --git a/library/flux_utils.py b/library/flux_utils.py index 37166933..68083616 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -1,5 +1,5 @@ import json -from typing import Union +from typing import Optional, Union import einops import torch @@ -20,7 +20,9 @@ MODEL_VERSION_FLUX_V1 = "flux1" # temporary copy from sd3_utils TODO refactor -def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: torch.dtype = torch.float32): +def load_safetensors( + path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32 +): if disable_mmap: # return safetensors.torch.load(open(path, "rb").read()) # use experimental loader @@ -38,11 +40,13 @@ def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: def load_flow_model( - name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False + name: str, ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False ) -> flux_models.Flux: logger.info(f"Building Flux model {name}") with torch.device("meta"): - model = flux_models.Flux(flux_models.configs[name].params).to(dtype) + model = flux_models.Flux(flux_models.configs[name].params) + if dtype is not None: + model = model.to(dtype) # load_sft doesn't support torch.device logger.info(f"Loading state dict from {ckpt_path}") diff --git a/library/utils.py b/library/utils.py index a1620997..d355cb10 100644 --- a/library/utils.py +++ b/library/utils.py @@ -82,6 +82,66 @@ def setup_logging(args=None, log_level=None, reset=False): logger.info(msg_init) +def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: + """ + Convert a string to a torch.dtype + + Args: + s: string representation of the dtype + default_dtype: default dtype to return if s is None + + Returns: + torch.dtype: the corresponding torch.dtype + + Raises: + ValueError: if the dtype is not supported + + Examples: + >>> str_to_dtype("float32") + torch.float32 + >>> str_to_dtype("fp32") + torch.float32 + >>> str_to_dtype("float16") + torch.float16 + >>> str_to_dtype("fp16") + torch.float16 + >>> str_to_dtype("bfloat16") + torch.bfloat16 + >>> str_to_dtype("bf16") + torch.bfloat16 + >>> str_to_dtype("fp8") + torch.float8_e4m3fn + >>> str_to_dtype("fp8_e4m3fn") + torch.float8_e4m3fn + >>> str_to_dtype("fp8_e4m3fnuz") + torch.float8_e4m3fnuz + >>> str_to_dtype("fp8_e5m2") + torch.float8_e5m2 + >>> str_to_dtype("fp8_e5m2fnuz") + torch.float8_e5m2fnuz + """ + if s is None: + return default_dtype + if s in ["bf16", "bfloat16"]: + return torch.bfloat16 + elif s in ["fp16", "float16"]: + return torch.float16 + elif s in ["fp32", "float32", "float"]: + return torch.float32 + elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]: + return torch.float8_e4m3fn + elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]: + return torch.float8_e4m3fnuz + elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]: + return torch.float8_e5m2 + elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]: + return torch.float8_e5m2fnuz + elif s in ["fp8", "float8"]: + return torch.float8_e4m3fn # default fp8 + else: + raise ValueError(f"Unsupported dtype: {s}") + + def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None): """ memory efficient save file @@ -198,7 +258,7 @@ class MemoryEfficientSafeOpen: if tensor_bytes is None: byte_tensor = torch.empty(0, dtype=torch.uint8) else: - tensor_bytes = bytearray(tensor_bytes) # make it writable + tensor_bytes = bytearray(tensor_bytes) # make it writable byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8) # process float8 types diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index d5e82920..2e0d4c29 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -8,7 +8,7 @@ from safetensors import safe_open from safetensors.torch import load_file, save_file from tqdm import tqdm -from library.utils import setup_logging +from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file setup_logging() import logging @@ -34,18 +34,23 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, state_dict, dtype, metadata): +def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False): if dtype is not None: logger.info(f"converting to {dtype}...") - for key in list(state_dict.keys()): + for key in tqdm(list(state_dict.keys())): if type(state_dict[key]) == torch.Tensor: state_dict[key] = state_dict[key].to(dtype) logger.info(f"saving to: {file_name}") - save_file(state_dict, file_name, metadata=metadata) + if mem_eff_save: + mem_eff_save_file(state_dict, file_name, metadata=metadata) + else: + save_file(state_dict, file_name, metadata=metadata) -def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): +def merge_to_flux_model( + loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False +): # create module map without loading state_dict logger.info(f"loading keys from FLUX.1 model: {flux_model}") lora_name_to_module_key = {} @@ -57,7 +62,14 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") lora_name_to_module_key[lora_name] = key - flux_state_dict = load_file(flux_model, device=loading_device) + if mem_eff_load_save: + flux_state_dict = {} + with MemoryEfficientSafeOpen(flux_model) as flux_file: + for key in tqdm(flux_file.keys()): + flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + else: + flux_state_dict = load_file(flux_model, device=loading_device) + for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU @@ -120,9 +132,17 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati return flux_state_dict -def merge_to_flux_model_diffusers(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype): +def merge_to_flux_model_diffusers( + loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False +): logger.info(f"loading keys from FLUX.1 model: {flux_model}") - flux_state_dict = load_file(flux_model, device=loading_device) + if mem_eff_load_save: + flux_state_dict = {} + with MemoryEfficientSafeOpen(flux_model) as flux_file: + for key in tqdm(flux_file.keys()): + flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + else: + flux_state_dict = load_file(flux_model, device=loading_device) def create_key_map(n_double_layers, n_single_layers): key_map = {} @@ -474,19 +494,15 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): def merge(args): + if args.models is None: + args.models = [] + if args.ratios is None: + args.ratios = [] + assert len(args.models) == len( args.ratios ), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" - def str_to_dtype(p): - if p == "float": - return torch.float - if p == "fp16": - return torch.float16 - if p == "bf16": - return torch.bfloat16 - return None - merge_dtype = str_to_dtype(args.precision) save_dtype = str_to_dtype(args.save_precision) if save_dtype is None: @@ -500,11 +516,25 @@ def merge(args): if args.flux_model is not None: if not args.diffusers: state_dict = merge_to_flux_model( - args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + args.loading_device, + args.working_device, + args.flux_model, + args.models, + args.ratios, + merge_dtype, + save_dtype, + args.mem_eff_load_save, ) else: state_dict = merge_to_flux_model_diffusers( - args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype + args.loading_device, + args.working_device, + args.flux_model, + args.models, + args.ratios, + merge_dtype, + save_dtype, + args.mem_eff_load_save, ) if args.no_metadata: @@ -517,7 +547,7 @@ def merge(args): ) logger.info(f"saving FLUX model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, sai_metadata) + save_to_file(args.save_to, state_dict, save_dtype, sai_metadata, args.mem_eff_load_save) else: state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) @@ -546,14 +576,14 @@ def setup_parser() -> argparse.ArgumentParser: "--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], - help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", + help="precision in saving, same to merging if omitted. supported types: " + "float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz" + " / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", ) parser.add_argument( "--precision", type=str, default="float", - choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", ) parser.add_argument( @@ -562,6 +592,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする", ) + parser.add_argument( + "--mem_eff_load_save", + action="store_true", + help="use custom memory efficient load and save functions for FLUX.1 model" + " / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する", + ) parser.add_argument( "--loading_device", type=str, From a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 27 Aug 2024 21:44:10 +0900 Subject: [PATCH 149/356] update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0108ada5..7b1d9cc6 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ The command to install PyTorch is as follows: Aug 27, 2024 (update 2): In FLUX.1 LoRA training, when `--fp8_base` is specified, the FLUX.1 model file with fp8 (`float8_e4m3fn` type) can be loaded directly. Also, in `flux_minimal_inference.py`, it is possible to load it by specifying `fp8 (float8_e4m3fn)` in `--flux_dtype`. -In `flux_merge_lora.py`, you can now specify the precision at save time with `fp8` (see `--help` for details). Also, if you do not specify the merge model, only the model type conversion will be performed. +In `flux_merge_lora.py`, you can now specify `fp8` for the save precision (see `--help` for details). Also, if you do not specify the merge model, only the dtype conversion will be performed. Aug 27, 2024: From 6c0e8a5a1740dbd50a0a45ec1f08983877605cd7 Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 29 Aug 2024 14:50:29 +0800 Subject: [PATCH 150/356] make guidance_scale keep float in args --- flux_train_network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flux_train_network.py b/flux_train_network.py index 4a63c2de..354a8c6f 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -324,7 +324,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) # get guidance - guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device) + # ensure guidance_scale in args is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # ensure the hidden state will require grad if args.gradient_checkpointing: From a0cfb0894c4be4ea27412e4c12ed13f68b57094b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 29 Aug 2024 21:20:33 +0900 Subject: [PATCH 151/356] Cleaned up README --- README.md | 299 +++++++++++++++++++++++++++--------------------------- 1 file changed, 152 insertions(+), 147 deletions(-) diff --git a/README.md b/README.md index 7b1d9cc6..a73eead0 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ This repository contains training, generation and utility scripts for Stable Diffusion. -## FLUX.1 LoRA training (WIP) +## FLUX.1 training (WIP) This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. @@ -9,127 +9,24 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` -Aug 27, 2024 (update 2): -In FLUX.1 LoRA training, when `--fp8_base` is specified, the FLUX.1 model file with fp8 (`float8_e4m3fn` type) can be loaded directly. Also, in `flux_minimal_inference.py`, it is possible to load it by specifying `fp8 (float8_e4m3fn)` in `--flux_dtype`. - -In `flux_merge_lora.py`, you can now specify `fp8` for the save precision (see `--help` for details). Also, if you do not specify the merge model, only the dtype conversion will be performed. - -Aug 27, 2024: - -- FLUX.1 LoRA training now supports CLIP-L LoRA. Please remove `--network_train_unet_only`. T5XXL is not trained. The output of T5XXL is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. The trained LoRA can be used with ComfyUI. - - `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. -- `--sigmoid_scale` is now effective even when `--timestep_sampling shift` is specified. Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. - -- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is automatically enabled. - -Aug 25, 2024: -Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`. -Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` - -Aug 24, 2024 (update 2): - -__Experimental__ Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them in FLUX.1 LoRA training. Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available). - -The number of parameters may increase slightly, so the expressiveness may increase, but the training time may be longer. No detailed verification has been done. - -This implementation is experimental, so it may be deprecated or changed in the future. - -The .safetensors file of the trained model is compatible with the normal LoRA model of sd-scripts, so it should be usable in inference environments such as ComfyUI as it is. Also, converting it to AI-toolkit (Diffusers) format with `convert_flux_lora.py` will reduce the size. It should be no problem to convert it if you use it in the inference environment. - -Technical details: In the implementation of Black Forest Labs' model, the projection layers of q/k/v (and txt in single blocks) are concatenated into one. If LoRA is added there as it is, the LoRA module is only one, and the dimension is large. In contrast, in the implementation of Diffusers, the projection layers of q/k/v/txt are separated. Therefore, the LoRA module is applied to q/k/v/txt separately, and the dimension is smaller. This option is for training LoRA similar to the latter. - -The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large. - -Aug 24, 2024: -Fixed an issue where the attention mask was not applied in single blocks when `--apply_t5_attn_mask` was specified. - -Aug 22, 2024 (update 2): -Fixed a bug that the embedding was zero-padded when `--apply_t5_attn_mask` option was applied. Also, the cache file for text encoder outputs now records whether the mask is applied or not. Please note that the cache file will be recreated when switching the `--apply_t5_attn_mask` option. - -Added a script to extract LoRA from the difference between the two models of FLUX.1. Use `networks/flux_extract_lora.py`. See `--help` for details. Normally, more than 50GB of memory is required, but specifying the `--mem_eff_safe_open` option significantly reduces memory usage. However, this option is a custom implementation, so unexpected problems may occur. Please always check if the model is loaded correctly. - -Aug 22, 2024: -Fixed a bug in multi-GPU training. It should work with fine-tuning and LoRA training. `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. - -`--disable_mmap_load_safetensors` option now works in `flux_train.py`. It speeds up model loading during training in WSL2. It is also effective in reducing memory usage when loading models during multi-GPU training. Please always check if the model is loaded correctly, as it uses a custom implementation of safetensors loading. - - -Aug 21, 2024 (update 3): -- There is a bug that `--full_bf16` option is enabled even if it is not specified in `flux_train.py`. The bug will be fixed sooner. __Please specify the `--full_bf16` option explicitly, especially when training with 24GB VRAM.__ -- Stochastic rounding is now implemented when `--fused_backward_pass` is specified. The implementation is -based on the code provided by 2kpr. Thank you so much! - - With this change, `--fused_backward_pass` is recommended over `--blockwise_fused_optimizers` when `--full_bf16` is specified. - - Please note that `--fused_backward_pass` is only supported with Adafactor. -- The sample command in [FLUX.1 fine-tuning](#flux1-fine-tuning) is updated to reflect these changes. -- Fixed `--single_blocks_to_swap` is not working in `flux_train.py`. - -Aug 21, 2024 (update 2): -Fixed an error in applying mask in Attention. The attention mask was float, but it should be bool. - -Added a script `convert_flux_lora.py` to convert LoRA between sd-scripts format (BFL-based) and AI-toolkit format (Diffusers-based). See `--help` for details. BFL-based LoRA has a large module, so converting it to Diffusers format may reduce temporary memory usage in the inference environment. Note that re-conversion will increase the size of LoRA. - - -Aug 21, 2024: -The specification of `--apply_t5_attn_mask` has been changed. Previously, the T5 output was zero-padded, but now, two steps are taken: "1. Apply mask when encoding T5" and "2. Apply mask in the attention of Double Block". Fine tuning, LoRA training, and inference in `flux_mini_inference.py` have been changed. - -Aug 20, 2024 (update 3): -__Experimental__ The multi-resolution training is now supported with caching latents to disk. - -The cache files now hold latents for multiple resolutions. Since the latents are appended to the current cache file, it is recommended to delete the cache file in advance (if not, the old latents is kept in .npz file). - -See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. - -Aug 20, 2024 (update 2): -`flux_merge_lora.py` now supports LoRA from AI-toolkit (Diffusers based keys). Specify `--diffusers` option to merge LoRA with Diffusers based keys. Thanks to exveria1015! - -Aug 20, 2024: -FLUX.1 supports multi-resolution inference, so training at multiple resolutions may be possible and the results may be improved (like 1024x1024, 768x768 and 512x512 ... you can use any resolution). - -The script seems to support multi-resolution even in the current version, ~~if `--cache_latents_to_disk` is not specified~~ -> `--cache_latents_to_disk` is now supported for multi-resolution training. Please try if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details. - -We will support multi-resolution caching to disk in the near future. - -Aug 19, 2024: -In `flux_train.py`, the memory consumption during model saving is reduced when `--save_precision` is set to the same value as `--mixed_precision` (about 22GB). Please set the same value unless there is a reason. - -An experimental option `--mem_eff_save` is also added. When specified, it can further reduce memory consumption (about 22GB), but since it is a custom implementation, unexpected problems may occur. We do not recommend using it unless you are familiar with the code. - -Aug 18, 2024: -Memory-efficient training based on 2kpr's implementation is implemented in `flux_train.py`. Thanks to 2kpr! See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. - -Aug 17, 2024: -Added a script `flux_train.py` to train FLUX.1. The script is experimental and not an optimized version. It needs >28GB VRAM for training. - -Aug 16, 2024: - -Added a script `networks/flux_merge_lora.py` to merge LoRA into FLUX.1 checkpoint. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. - -FLUX.1 schnell model based training is now supported (but not tested). If the name of the model file contains `schnell`, the model is treated as a schnell model. - -Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. The default is 512 in dev and 256 in schnell. - -Previously, when `--max_token_length` was specified, that value was used, and 512 was used when omitted (default). Therefore, there is no impact if `--max_token_length` was not specified. If `--max_token_length` was specified, please specify `--t5xxl_max_token_length` instead. `--max_token_length` is ignored during FLUX.1 training. - -Aug 14, 2024: Sample image generation during training is now supported. Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. It will be very slow when `--split_mode` is specified. - -Aug 13, 2024: - -__Experimental__ A network argument `train_blocks` is added to `lora_flux`. This is to select the target blocks of LoRA from FLUX double blocks and single blocks. Specify like `--network_args "train_blocks=single"`. `all` trains both double blocks and single blocks, `double` trains only double blocks, and `single` trains only single blocks. The default (omission) is `all`. - -This argument is available even if `--split_mode` is not specified. - -__Experimental__ `--split_mode` option is added to `flux_train_network.py`. This splits FLUX into double blocks and single blocks for training. By enabling gradients only for the single blocks part, memory usage is reduced. When this option is specified, you need to specify `"train_blocks=single"` in the network arguments. - -This option enables training with 12GB VRAM GPUs, but the training speed is 2-3 times slower than the default. - -Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-generate the latents cache file if you have used the option before. - -Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. - +- [FLUX.1 LoRA training](#flux1-lora-training) + - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) + - [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model) + - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) +- [FLUX.1 fine-tuning](#flux1-fine-tuning) + - [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning) +- [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models) +- [Convert FLUX LoRA](#convert-flux-lora) +- [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) +- [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) ### FLUX.1 LoRA training -We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. +We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. + +FLUX.1 model, CLIP-L, and T5XXL models are recommended to be in bf16/fp16 format. If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with `float8_e4m3fn` format. + +Sample command is below. It will work with 24GB VRAM GPUs. ``` accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py @@ -137,46 +34,107 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 ---network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base +--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml ---output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigmoid ---model_prediction_type raw --guidance_scale 1.0 --loss_type l2 +--output_dir path/to/output/dir --output_name flux-lora-name +--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 ``` (The command is multi-line for readability. Please combine it into one line.) The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below: ``` ---optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 ``` The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_mode` and `train_blocks=single` options. Please use settings like below: ``` ---optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 ``` -LoRAs for Text Encoders are not tested yet. - -We have added some new options (Aug 10, 2024): `--time_sampling`, `--sigmoid_scale`, `--model_prediction_type` and `--discrete_flow_shift`. The options are as follows: - -- `--timestep_sampling` is the method to sample timesteps (0-1): `sigma` (sigma-based, same as SD3), `uniform` (uniform random), or `sigmoid` (sigmoid of random normal, same as x-flux). -- `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform. -- `--model_prediction_type` is how to interpret and process the model prediction: `raw` (use as is, same as x-flux), `additive` (add to noisy input), `sigma_scaled` (apply sigma scaling, same as SD3). -- `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3). - -`--loss_type` may be useful for FLUX.1 training. The default is `l2`. - -In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. - -additional note (Aug 11): A quick check shows that the settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). This seems to be a good starting point. Thanks to Ostris for the great work! - -Other settings may work better, so please try different settings. - We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. The trained LoRA model can be used with ComfyUI. +#### Key Options for FLUX.1 LoRA training + +There are many unknown points in FLUX.1 training, so some settings can be specified by arguments. Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome. + +- `--timestep_sampling` is the method to sample timesteps (0-1): + - `sigma`: sigma-based, same as SD3 + - `uniform`: uniform random + - `sigmoid`: sigmoid of random normal, same as x-flux, AI-toolkit etc. + - `shift`: shifts the value of sigmoid of normal distribution random number +- `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform. + - This option is effective even when`--timestep_sampling shift` is specified. + - Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. +- `--model_prediction_type` is how to interpret and process the model prediction: + - `raw`: use as is, same as x-flux + - `additive`: add to noisy input + - `sigma_scaled`: apply sigma scaling, same as SD3 +- `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3). + +The existing `--loss_type` option may be useful for FLUX.1 training. The default is `l2`. + +~~In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. ~~ + +In our experiments, `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` with `--loss_type l2` seems to work better than other settings. + +The settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). + +Other settings may work better, so please try different settings. + +Other options are described below. + +#### Distribution of timesteps + +`--timestep_sampling` and `--sigmoid_scale`, `--discrete_flow_shift` adjust the distribution of timesteps. The distribution is shown in the figures below. + +The effect of `--discrete_flow_shift` with `--timestep_sampling shift` (when `--sigmoid_scale` is not specified, the default is 1.0): + +The difference between `--timestep_sampling uniform` and `--timestep_sampling sigma`: + +The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--timestep_sampling sigmoid` is specified, `--discrete_flow_shift` is ignored): + +#### Key Features for FLUX.1 LoRA training + +1. CLIP-L LoRA Support: + - FLUX.1 LoRA training now supports CLIP-L LoRA. + - Remove `--network_train_unet_only` from your command. + - T5XXL is not trained. Its output is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. + - The trained LoRA can be used with ComfyUI. + - Note: `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. + +2. Experimental FP8/FP16 mixed training: + - `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L. + - FLUX can be trained with fp8, and CLIP-L can be trained with bf16/fp16. + - When specifying this option, the `--fp8_base` option is automatically enabled. + +3. Split Q/K/V Projection Layers (Experimental): + - Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them. + - Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available). + - May increase expressiveness but also training time. + - The trained model is compatible with normal LoRA models in sd-scripts and can be used in environments like ComfyUI. + - Converting to AI-toolkit (Diffusers) format with `convert_flux_lora.py` will reduce the size. + +4. T5 Attention Mask Application: + - T5 attention mask is applied when `--apply_t5_attn_mask` is specified. + - Now applies mask when encoding T5 and in the attention of Double and Single Blocks + - Affects fine-tuning, LoRA training, and inference in `flux_minimal_inference.py`. + +5. Multi-resolution Training Support: + - FLUX.1 now supports multi-resolution training, even with caching latents to disk. + + +Technical details of Q/K/V split: + +In the implementation of Black Forest Labs' model, the projection layers of q/k/v (and txt in single blocks) are concatenated into one. If LoRA is added there as it is, the LoRA module is only one, and the dimension is large. In contrast, in the implementation of Diffusers, the projection layers of q/k/v/txt are separated. Therefore, the LoRA module is applied to q/k/v/txt separately, and the dimension is smaller. This option is for training LoRA similar to the latter. + +The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large. + +### Inference for FLUX.1 with LoRA model + The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. ``` @@ -185,6 +143,8 @@ python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safete ### FLUX.1 fine-tuning +The memory-efficient training with block swap is based on 2kpr's implementation. Thanks to 2kpr! + Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GPUs, and 64GB main memory is recommended. ``` @@ -195,15 +155,13 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t --dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name --learning_rate 5e-5 --max_train_epochs 4 --sdpa --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1 --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" ---timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 +--lr_scheduler constant_with_warmup --max_grad_norm 0.0 +--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 --fused_backward_pass --double_blocks_to_swap 6 --cpu_offload_checkpointing --full_bf16 ``` +(The command is multi-line for readability. Please combine it into one line.) -(Combine the command into one line.) - -Sample image generation during training is not tested yet. - -Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. +Options are almost the same as LoRA training. The difference is `--full_bf16`, `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. `--full_bf16` enables the training with bf16 (weights and gradients). @@ -223,6 +181,53 @@ Swap 6 double blocks and use cpu offload checkpointing may be a good starting po The learning rate and the number of epochs are not optimized yet. Please adjust them according to the training results. +#### Key Features for FLUX.1 fine-tuning + +1. Sample Image Generation: + - Sample image generation during training is now supported. + - The prompts are cached and used for generation if `--cache_latents` is specified. So changing the prompts during training will not affect the generated images. + - Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. + - Note: It will be very slow when `--split_mode` is specified. + +2. Experimental Memory-Efficient Saving: + - `--mem_eff_save` option can further reduce memory consumption during model saving (about 22GB). + - This is a custom implementation and may cause unexpected issues. Use with caution. + +3. T5XXL Token Length Control: + - Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. + - Default is 512 in dev and 256 in schnell models. + +4. Multi-GPU Training Support: + - Note: `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. + +5. Disable mmap Load for Safetensors: + - `--disable_mmap_load_safetensors` option now works in `flux_train.py`. + - Speeds up model loading during training in WSL2. + - Effective in reducing memory usage when loading models during multi-GPU training. + + +### Extract LoRA from FLUX.1 Models + +Script: `networks/flux_extract_lora.py` + +Extracts LoRA from the difference between two FLUX.1 models. + +Offers memory-efficient option with `--mem_eff_safe_open`. + +CLIP-L LoRA is not supported. + +### Convert FLUX LoRA + +Script: `convert_flux_lora.py` + +Converts LoRA between sd-scripts format (BFL-based) and AI-toolkit format (Diffusers-based). + +If you use LoRA in the inference environment, converting it to AI-toolkit format may reduce temporary memory usage. + +Note that re-conversion will increase the size of LoRA. + +CLIP-L LoRA is not supported. + ### Merge LoRA to FLUX.1 checkpoint `networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__ From daa6ad516581872aa6acaa15c0d24aad4f998838 Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Thu, 29 Aug 2024 21:25:30 +0900 Subject: [PATCH 152/356] Update README.md --- README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a73eead0..6e2ae337 100644 --- a/README.md +++ b/README.md @@ -77,9 +77,9 @@ There are many unknown points in FLUX.1 training, so some settings can be specif The existing `--loss_type` option may be useful for FLUX.1 training. The default is `l2`. -~~In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. ~~ +~~In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted.~~ -In our experiments, `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` with `--loss_type l2` seems to work better than other settings. +In our experiments, `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type) seems to work better. The settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). @@ -92,10 +92,13 @@ Other options are described below. `--timestep_sampling` and `--sigmoid_scale`, `--discrete_flow_shift` adjust the distribution of timesteps. The distribution is shown in the figures below. The effect of `--discrete_flow_shift` with `--timestep_sampling shift` (when `--sigmoid_scale` is not specified, the default is 1.0): +![Figure_2](https://github.com/user-attachments/assets/d9de42f9-f17d-40da-b88d-d964402569c6) -The difference between `--timestep_sampling uniform` and `--timestep_sampling sigma`: +The difference between `--timestep_sampling sigmoid` and `--timestep_sampling uniform` (when `--timestep_sampling sigmoid` or `uniform` is specified, `--discrete_flow_shift` is ignored): +![Figure_3](https://github.com/user-attachments/assets/27029009-1f5d-4dc0-bb24-13d02ac4fdad) The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--timestep_sampling sigmoid` is specified, `--discrete_flow_shift` is ignored): +![Figure_4](https://github.com/user-attachments/assets/08a2267c-e47e-48b7-826e-f9a080787cdc) #### Key Features for FLUX.1 LoRA training From 8ecf0fc4bfd1b03cfc6fd4055af0b3363f5d1f38 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 29 Aug 2024 22:10:57 +0900 Subject: [PATCH 153/356] Refactor code to ensure args.guidance_scale is always a float #1525 --- flux_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flux_train.py b/flux_train.py index 410728d4..32a36f03 100644 --- a/flux_train.py +++ b/flux_train.py @@ -688,8 +688,8 @@ def train(args): packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) - # get guidance - guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device) + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # call model l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds From 8fdfd8c857a88aaa78ac9c2488432ef8115982f2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 29 Aug 2024 22:26:29 +0900 Subject: [PATCH 154/356] Update safetensors to version 0.4.4 in requirements.txt #1524 --- README.md | 7 +++++++ requirements.txt | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 6e2ae337..30264e73 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,13 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +### Recent Updates + +Aug 29, 2024: +Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `requirements.txt` is updated. + +### Contents + - [FLUX.1 LoRA training](#flux1-lora-training) - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) - [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model) diff --git a/requirements.txt b/requirements.txt index 4ee19b3e..4c1bc392 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ bitsandbytes==0.43.3 prodigyopt==1.0 lion-pytorch==0.0.6 tensorboard -safetensors==0.4.2 +safetensors==0.4.4 # gradio==3.16.2 altair==4.2.2 easygui==0.98.3 From 34f2315047f8d5b89b7a8a6093bb56679bff13c3 Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 29 Aug 2024 22:33:37 +0800 Subject: [PATCH 155/356] fix: text_encoder_conds referenced before assignment --- train_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index 048c7e7b..628c421c 100644 --- a/train_network.py +++ b/train_network.py @@ -1081,12 +1081,12 @@ class NetworkTrainer: # print(f"set multiplier: {multipliers}") accelerator.unwrap_model(network).set_multiplier(multipliers) + text_encoder_conds = [] text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs if ( - text_encoder_conds is None - or len(text_encoder_conds) == 0 + len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder ): From 35882f8d5bbd076a97622cf6193c988621481803 Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 29 Aug 2024 23:03:43 +0800 Subject: [PATCH 156/356] fix --- train_network.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index 628c421c..4204bce3 100644 --- a/train_network.py +++ b/train_network.py @@ -1112,10 +1112,14 @@ class NetworkTrainer: if args.full_fp16: encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] - # if encoded_text_encoder_conds is not None, update cached text_encoder_conds - for i in range(len(encoded_text_encoder_conds)): - if encoded_text_encoder_conds[i] is not None: - text_encoder_conds[i] = encoded_text_encoder_conds[i] + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( From 2a3aefb4e44dce1f189677d0a996ba0244633956 Mon Sep 17 00:00:00 2001 From: Nando Metzger <42088121+nandometzger@users.noreply.github.com> Date: Fri, 30 Aug 2024 08:15:05 +0200 Subject: [PATCH 157/356] Update train_util.py, bug fix --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3c..0fec565d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1489,7 +1489,7 @@ class DreamBoothDataset(BaseDataset): def load_dreambooth_dir(subset: DreamBoothSubset): if not os.path.isdir(subset.image_dir): logger.warning(f"not directory: {subset.image_dir}") - return [], [] + return [], [], [] info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE) use_cached_info_for_subset = subset.cache_info From 3a6154b7b0dbcae82d24adacf5a76f75288b98f4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 31 Aug 2024 06:21:16 +0000 Subject: [PATCH 158/356] Bump opencv-python from 4.7.0.68 to 4.8.1.78 Bumps [opencv-python](https://github.com/opencv/opencv-python) from 4.7.0.68 to 4.8.1.78. - [Release notes](https://github.com/opencv/opencv-python/releases) - [Commits](https://github.com/opencv/opencv-python/commits) --- updated-dependencies: - dependency-name: opencv-python dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e99775b8..977c5cd9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ transformers==4.36.2 diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 -opencv-python==4.7.0.68 +opencv-python==4.8.1.78 einops==0.7.0 pytorch-lightning==1.9.0 bitsandbytes==0.43.0 From 25c9040f4fbbcbddc0297895369337846152fea4 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 31 Aug 2024 03:05:19 +0800 Subject: [PATCH 159/356] Update flux_train_utils.py --- library/flux_train_utils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index a8e94ac0..735bcced 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -371,7 +371,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents, noise, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - bsz = latents.shape[0] + bsz, _, H, W = latents.shape sigmas = None if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": @@ -392,6 +392,16 @@ def get_noisy_model_input_and_timesteps( timesteps = logits_norm.sigmoid() timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "flux_shift": + logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + mu=get_lin_function(y1=0.5, y2=1.15)((H//2) * (W//2)) + timesteps = time_shift(mu, 1.0, timesteps) + t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 noisy_model_input = (1 - t) * latents + t * noise @@ -571,7 +581,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--timestep_sampling", - choices=["sigma", "uniform", "sigmoid", "shift"], + choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"], default="sigma", help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid." " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。", From 1bcf8d600bfb9f4314a41a12a5e7b272a17ceaed Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 1 Sep 2024 01:33:04 +0000 Subject: [PATCH 160/356] Bump crate-ci/typos from 1.19.0 to 1.24.3 Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.19.0 to 1.24.3. - [Release notes](https://github.com/crate-ci/typos/releases) - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md) - [Commits](https://github.com/crate-ci/typos/compare/v1.19.0...v1.24.3) --- updated-dependencies: - dependency-name: crate-ci/typos dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/typos.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index e8b06483..0149dcdd 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -18,4 +18,4 @@ jobs: - uses: actions/checkout@v4 - name: typos-action - uses: crate-ci/typos@v1.19.0 + uses: crate-ci/typos@v1.24.3 From ef510b3cb94427d72df681389e1214251813b1a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Sun, 1 Sep 2024 17:41:01 +0800 Subject: [PATCH 161/356] Sd3 freeze x_block (#1417) * Update sd3_train.py * add freeze block lr * Update train_util.py * update --- library/train_util.py | 21 +++++++++++++++++++++ sd3_train.py | 9 ++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 989758ad..74aae0a7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3246,6 +3246,12 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): default=None, help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", ) + parser.add_argument( + "--num_last_block_to_freeze", + type=int, + default=None, + help="num_last_block_to_freeze", + ) def add_optimizer_arguments(parser: argparse.ArgumentParser): @@ -5758,6 +5764,21 @@ def sample_image_inference( pass +def freeze_blocks(model, num_last_block_to_freeze, block_name="x_block"): + + filtered_blocks = [(name, param) for name, param in model.named_parameters() if block_name in name] + print(f"filtered_blocks: {len(filtered_blocks)}") + + num_blocks_to_freeze = min(len(filtered_blocks), num_last_block_to_freeze) + + print(f"freeze_blocks: {num_blocks_to_freeze}") + + start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze) + + for i in range(start_freezing_from, len(filtered_blocks)): + _, param = filtered_blocks[i] + param.requires_grad = False + # endregion diff --git a/sd3_train.py b/sd3_train.py index 3b6c8a11..ce9500b0 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -368,12 +368,19 @@ def train(args): vae.eval() vae.to(accelerator.device, dtype=vae_dtype) + mmdit.requires_grad_(train_mmdit) + if not train_mmdit: + mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared + + if args.num_last_block_to_freeze: + train_util.freeze_blocks(mmdit,num_last_block_to_freeze=args.num_last_block_to_freeze) + training_models = [] params_to_optimize = [] # if train_unet: training_models.append(mmdit) # if block_lrs is None: - params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate}) + params_to_optimize.append({"params": list(filter(lambda p: p.requires_grad, mmdit.parameters())), "lr": args.learning_rate}) # else: # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs)) From 92e7600cc2fea604321004f260e7db76c764f388 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Sep 2024 18:57:07 +0900 Subject: [PATCH 162/356] Move freeze_blocks to sd3_train because it's only for sd3 --- README.md | 3 +++ library/train_util.py | 21 --------------------- sd3_train.py | 22 ++++++++++++++++++++-- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 30264e73..d9636719 100644 --- a/README.md +++ b/README.md @@ -309,6 +309,9 @@ resolution = [512, 512] SD3 training is done with `sd3_train.py`. +__Sep 1, 2024__: +- `--num_last_block_to_freeze` is added to `sd3_train.py`. This option is to freeze the last n blocks of the MMDiT. See [#1417](https://github.com/kohya-ss/sd-scripts/pull/1417) for details. Thanks to sdbds! + __Jul 27, 2024__: - Latents and text encoder outputs caching mechanism is refactored significantly. - Existing cache files for SD3 need to be recreated. Please delete the previous cache files. diff --git a/library/train_util.py b/library/train_util.py index 74aae0a7..989758ad 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3246,12 +3246,6 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): default=None, help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", ) - parser.add_argument( - "--num_last_block_to_freeze", - type=int, - default=None, - help="num_last_block_to_freeze", - ) def add_optimizer_arguments(parser: argparse.ArgumentParser): @@ -5764,21 +5758,6 @@ def sample_image_inference( pass -def freeze_blocks(model, num_last_block_to_freeze, block_name="x_block"): - - filtered_blocks = [(name, param) for name, param in model.named_parameters() if block_name in name] - print(f"filtered_blocks: {len(filtered_blocks)}") - - num_blocks_to_freeze = min(len(filtered_blocks), num_last_block_to_freeze) - - print(f"freeze_blocks: {num_blocks_to_freeze}") - - start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze) - - for i in range(start_freezing_from, len(filtered_blocks)): - _, param = filtered_blocks[i] - param.requires_grad = False - # endregion diff --git a/sd3_train.py b/sd3_train.py index ce9500b0..87011b21 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -373,7 +373,20 @@ def train(args): mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared if args.num_last_block_to_freeze: - train_util.freeze_blocks(mmdit,num_last_block_to_freeze=args.num_last_block_to_freeze) + # freeze last n blocks of MM-DIT + block_name = "x_block" + filtered_blocks = [(name, param) for name, param in mmdit.named_parameters() if block_name in name] + accelerator.print(f"filtered_blocks: {len(filtered_blocks)}") + + num_blocks_to_freeze = min(len(filtered_blocks), args.num_last_block_to_freeze) + + accelerator.print(f"freeze_blocks: {num_blocks_to_freeze}") + + start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze) + + for i in range(start_freezing_from, len(filtered_blocks)): + _, param = filtered_blocks[i] + param.requires_grad = False training_models = [] params_to_optimize = [] @@ -1033,12 +1046,17 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", ) - parser.add_argument( "--skip_latents_validity_check", action="store_true", help="skip latents validity check / latentsの正当性チェックをスキップする", ) + parser.add_argument( + "--num_last_block_to_freeze", + type=int, + default=None, + help="freeze last n blocks of MM-DIT / MM-DITの最後のnブロックを凍結する", + ) return parser From 4f6d915d15262447b1049a78a55678b2825784a3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Sep 2024 19:12:29 +0900 Subject: [PATCH 163/356] update help and README --- README.md | 5 +++++ library/flux_train_utils.py | 8 ++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index d9636719..331951ef 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 1, 2024: +- `--timestamp_sampling` has `flux_shift` option. Thanks to sdbds! + - This is the same shift as FLUX.1 dev inference, adjusting the timestep sampling depending on the resolution. `--discrete_flow_shift` is ignored when `flux_shift` is specified. It is not verified which is better, `shift` or `flux_shift`. + Aug 29, 2024: Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `requirements.txt` is updated. @@ -73,6 +77,7 @@ There are many unknown points in FLUX.1 training, so some settings can be specif - `uniform`: uniform random - `sigmoid`: sigmoid of random normal, same as x-flux, AI-toolkit etc. - `shift`: shifts the value of sigmoid of normal distribution random number + - `flux_shift`: shifts the value of sigmoid of normal distribution random number, depending on the resolution (same as FLUX.1 dev inference). `--discrete_flow_shift` is ignored when `flux_shift` is specified. - `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform. - This option is effective even when`--timestep_sampling shift` is specified. - Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 735bcced..9dad4baa 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -371,7 +371,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents, noise, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - bsz, _, H, W = latents.shape + bsz, _, h, w = latents.shape sigmas = None if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": @@ -399,7 +399,7 @@ def get_noisy_model_input_and_timesteps( logits_norm = torch.randn(bsz, device=device) logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling timesteps = logits_norm.sigmoid() - mu=get_lin_function(y1=0.5, y2=1.15)((H//2) * (W//2)) + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) timesteps = time_shift(mu, 1.0, timesteps) t = timesteps.view(-1, 1, 1, 1) @@ -583,8 +583,8 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): "--timestep_sampling", choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"], default="sigma", - help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid." - " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。", ) parser.add_argument( "--sigmoid_scale", From 6abacf04da756808ffca567f6660445ecdf478bd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 2 Sep 2024 13:05:26 +0900 Subject: [PATCH 164/356] update README --- README.md | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 331951ef..5dd916aa 100644 --- a/README.md +++ b/README.md @@ -184,7 +184,7 @@ Options are almost the same as LoRA training. The difference is `--full_bf16`, ` `--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. -`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks. +`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks. Please see the next chapter for details. `--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. @@ -198,24 +198,32 @@ The learning rate and the number of epochs are not optimized yet. Please adjust #### Key Features for FLUX.1 fine-tuning -1. Sample Image Generation: +1. Technical details of double/single block swap: + - Reduce memory usage by transferring double and single blocks of FLUX.1 from GPU to CPU when they are not needed. + - During forward pass, the weights of the blocks that have finished calculation are transferred to CPU, and the weights of the blocks to be calculated are transferred to GPU. + - The same is true for the backward pass, but the order is reversed. The gradients remain on the GPU. + - Since the transfer between CPU and GPU takes time, the training will be slower. + - `--double_blocks_to_swap` and `--single_blocks_to_swap` specify the number of blocks to swap. For example, `--double_blocks_to_swap 6` swaps 6 blocks at each step of training, but the remaining 13 blocks are always on the GPU. + - About 640MB of memory can be saved per double block, and about 320MB of memory can be saved per single block. + +2. Sample Image Generation: - Sample image generation during training is now supported. - The prompts are cached and used for generation if `--cache_latents` is specified. So changing the prompts during training will not affect the generated images. - Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. - Note: It will be very slow when `--split_mode` is specified. -2. Experimental Memory-Efficient Saving: +3. Experimental Memory-Efficient Saving: - `--mem_eff_save` option can further reduce memory consumption during model saving (about 22GB). - This is a custom implementation and may cause unexpected issues. Use with caution. -3. T5XXL Token Length Control: +4. T5XXL Token Length Control: - Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. - Default is 512 in dev and 256 in schnell models. -4. Multi-GPU Training Support: +5. Multi-GPU Training Support: - Note: `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. -5. Disable mmap Load for Safetensors: +6. Disable mmap Load for Safetensors: - `--disable_mmap_load_safetensors` option now works in `flux_train.py`. - Speeds up model loading during training in WSL2. - Effective in reducing memory usage when loading models during multi-GPU training. From b65ae9b439e4324359014d6d720aa01def3a19fc Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 4 Sep 2024 21:33:17 +0900 Subject: [PATCH 165/356] T5XXL LoRA training, fp8 T5XXL support --- README.md | 45 +++++++++++---- flux_train_network.py | 112 +++++++++++++++++++++++++++++------- library/flux_train_utils.py | 23 ++++++-- library/flux_utils.py | 9 ++- library/strategy_flux.py | 13 ++++- networks/lora_flux.py | 39 ++++++++++--- train_network.py | 48 ++++++++++------ 7 files changed, 222 insertions(+), 67 deletions(-) diff --git a/README.md b/README.md index 5dd916aa..84065570 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,11 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 4, 2024: +- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. +- In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching. +- Fixed an issue where the training CLIP-L LoRA was not used in sample image generation during LoRA training. + Sep 1, 2024: - `--timestamp_sampling` has `flux_shift` option. Thanks to sdbds! - This is the same shift as FLUX.1 dev inference, adjusting the timestep sampling depending on the resolution. `--discrete_flow_shift` is ignored when `flux_shift` is specified. It is not verified which is better, `shift` or `flux_shift`. @@ -41,8 +46,8 @@ Sample command is below. It will work with 24GB VRAM GPUs. ``` accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py ---pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors ---ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers +--pretrained_model_name_or_path flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors +--ae ae.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base @@ -72,6 +77,11 @@ The trained LoRA model can be used with ComfyUI. There are many unknown points in FLUX.1 training, so some settings can be specified by arguments. Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome. +- `--pretrained_model_name_or_path` is the path to the pretrained model (FLUX.1). bf16 (original BFL model) is recommended (`flux1-dev.safetensors` or `flux1-dev.sft`). If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with `float8_e4m3fn` format. +- `--clip_l` is the path to the CLIP-L model. +- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (float8_e4m3fn) models for T5XXL. However, it is recommended to use fp16 models for caching. +- `--ae` is the path to the autoencoder model (`ae.safetensors` or `ae.sft`). + - `--timestep_sampling` is the method to sample timesteps (0-1): - `sigma`: sigma-based, same as SD3 - `uniform`: uniform random @@ -114,16 +124,29 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times #### Key Features for FLUX.1 LoRA training -1. CLIP-L LoRA Support: - - FLUX.1 LoRA training now supports CLIP-L LoRA. +1. CLIP-L and T5XXL LoRA Support: + - FLUX.1 LoRA training now supports CLIP-L and T5XXL LoRA training. - Remove `--network_train_unet_only` from your command. - - T5XXL is not trained. Its output is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. + - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time. + - T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. - The trained LoRA can be used with ComfyUI. - - Note: `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. + - Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet. + + | trained LoRA|option|network_args|cache_text_encoder_outputs (*1)| + |---|---|---|---| + |FLUX.1|`--network_train_unet_only`|-|o| + |FLUX.1 + CLIP-L|-|-|o (*2)| + |FLUX.1 + CLIP-L + T5XXL|-|`train_t5xxl=True`|-| + |CLIP-L (*3)|`--network_train_text_encoder_only`|-|o (*2)| + |CLIP-L + T5XXL (*3)|`--network_train_text_encoder_only`|`train_t5xxl=True`|-| + + - *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. + - *2: T5XXL output can be cached for CLIP-L LoRA training. + - *3: Not tested yet. 2. Experimental FP8/FP16 mixed training: - - `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L. - - FLUX can be trained with fp8, and CLIP-L can be trained with bf16/fp16. + - `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L/T5XXL. + - FLUX can be trained with fp8, and CLIP-L/T5XXL can be trained with bf16/fp16. - When specifying this option, the `--fp8_base` option is automatically enabled. 3. Split Q/K/V Projection Layers (Experimental): @@ -153,7 +176,7 @@ The compatibility of the saved model (state dict) is ensured by concatenating th The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. ``` -python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 +python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.safetensors --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` ### FLUX.1 fine-tuning @@ -164,7 +187,7 @@ Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GP ``` accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py ---pretrained_model_name_or_path flux1-dev.sft --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.sft +--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.safetensors --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name @@ -256,7 +279,7 @@ CLIP-L LoRA is not supported. `networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__ ``` -python networks/flux_merge_lora.py --flux_model flux1-dev.sft --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu +python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu ``` You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`. diff --git a/flux_train_network.py b/flux_train_network.py index 354a8c6f..2fc0f323 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -43,13 +43,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): train_dataset_group.is_text_encoder_output_cacheable() ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - # assert ( - # args.network_train_unet_only or not args.cache_text_encoder_outputs - # ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" - if not args.network_train_unet_only: - logger.info( - "network for CLIP-L only will be trained. T5XXL will not be trained / CLIP-Lのネットワークのみが学習されます。T5XXLは学習されません" - ) + # prepare CLIP-L/T5XXL training flags + self.train_clip_l = not args.network_train_unet_only + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") @@ -63,12 +59,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # currently offload to cpu for some models name = self.get_flux_model_name(args) - # if we load to cpu, flux.to(fp8) takes a long time - if args.fp8_base: - loading_dtype = None # as is - else: - loading_dtype = weight_dtype + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future model = flux_utils.load_flow_model( name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors ) @@ -85,9 +79,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) clip_l.eval() + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + # loading t5xxl to cpu takes a long time, so we should load to gpu in future - t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) @@ -154,25 +160,35 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def get_text_encoding_strategy(self, args): return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + def get_models_for_text_encoding(self, args, accelerator, text_encoders): if args.cache_text_encoder_outputs: - if self.is_train_text_encoder(args): + if self.train_clip_l and not self.train_t5xxl: return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached else: - return text_encoders # ignored + return None # no text encoders are needed for encoding because both are cached else: return text_encoders # both CLIP-L and T5XXL are needed for encoding def get_text_encoders_train_flags(self, args, text_encoders): - return [True, False] if self.is_train_text_encoder(args) else [False, False] + return [self.train_clip_l, self.train_t5xxl] def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True return strategy_flux.FluxTextEncoderOutputsCachingStrategy( args.cache_text_encoder_outputs_to_disk, None, False, - is_partial=self.is_train_text_encoder(args), + is_partial=self.train_clip_l or self.train_t5xxl, apply_t5_attn_mask=args.apply_t5_attn_mask, ) else: @@ -193,8 +209,16 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # When TE is not be trained, it will not be prepared so we need to use explicit autocast logger.info("move text encoders to gpu") - text_encoders[0].to(accelerator.device, dtype=weight_dtype) - text_encoders[1].to(accelerator.device, dtype=weight_dtype) + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device) + + if text_encoders[1].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[1].to(weight_dtype) + with accelerator.autocast(): dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) @@ -235,7 +259,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): else: # Text Encoderから毎回出力を取得するので、GPUに乗せておく text_encoders[0].to(accelerator.device, dtype=weight_dtype) - text_encoders[1].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device) # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -255,9 +279,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # return noise_pred def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): + text_encoders = text_encoder # for compatibility + text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + if not args.split_mode: flux_train_utils.sample_images( - accelerator, args, epoch, global_step, flux, ae, text_encoder, self.sample_prompts_te_outputs + accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs ) return @@ -281,7 +308,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) clean_memory_on_device(accelerator.device) flux_train_utils.sample_images( - accelerator, args, epoch, global_step, wrapper, ae, text_encoder, self.sample_prompts_te_outputs + accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs ) clean_memory_on_device(accelerator.device) @@ -421,6 +448,47 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def is_text_encoder_not_needed_for_training(self, args): return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + if index == 0: # CLIP-L + return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) + else: # T5XXL + text_encoder.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + if index == 0: # CLIP-L + logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + else: # T5XXL + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"T5XXL already prepared for fp8") + else: + logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + prepare_fp8(text_encoder, weight_dtype) + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 9dad4baa..0b5d4d90 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -85,7 +85,7 @@ def sample_images( if distributed_state.num_processes <= 1: # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. - with torch.no_grad(): + with torch.no_grad(), accelerator.autocast(): for prompt_dict in prompts: sample_image_inference( accelerator, @@ -187,14 +187,27 @@ def sample_image_inference( tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + text_encoder_conds = [] if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: - te_outputs = sample_prompts_te_outputs[prompt] - else: + text_encoder_conds = sample_prompts_te_outputs[prompt] + print(f"Using cached text encoder outputs for prompt: {prompt}") + if text_encoders is not None: + print(f"Encoding prompt: {prompt}") tokens_and_masks = tokenize_strategy.tokenize(prompt) # strategy has apply_t5_attn_mask option - te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + print([x.shape if x is not None else None for x in encoded_text_encoder_conds]) - l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds # sample image weight_dtype = ae.dtype # TOFO give dtype as argument diff --git a/library/flux_utils.py b/library/flux_utils.py index 68083616..7b0a41a8 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -171,7 +171,9 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev return clip -def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> T5EncoderModel: +def load_t5xxl( + ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False +) -> T5EncoderModel: T5_CONFIG_JSON = """ { "architectures": [ @@ -217,6 +219,11 @@ def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.devi return t5xxl +def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype: + # nn.Embedding is the first layer, but it could be casted to bfloat16 or float32 + return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype + + def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int): img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3) img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None] diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 5d083913..6c9ef5e4 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -5,8 +5,7 @@ import torch import numpy as np from transformers import CLIPTokenizer, T5TokenizerFast -from library import sd3_utils, train_util -from library import sd3_models +from library import flux_utils, train_util from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy from library.utils import setup_logging @@ -100,6 +99,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) self.apply_t5_attn_mask = apply_t5_attn_mask + self.warn_fp8_weights = False + def get_outputs_npz_path(self, image_abs_path: str) -> str: return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX @@ -144,6 +145,14 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): def cache_batch_outputs( self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List ): + if not self.warn_fp8_weights: + if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn: + logger.warning( + "T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs." + " / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。" + ) + self.warn_fp8_weights = True + flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy captions = [info.caption for info in infos] diff --git a/networks/lora_flux.py b/networks/lora_flux.py index fcb56a46..295267be 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -330,6 +330,11 @@ def create_network( if split_qkv is not None: split_qkv = True if split_qkv == "True" else False + # train T5XXL + train_t5xxl = kwargs.get("train_t5xxl", False) + if train_t5xxl is not None: + train_t5xxl = True if train_t5xxl == "True" else False + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoders, @@ -344,6 +349,7 @@ def create_network( conv_alpha=conv_alpha, train_blocks=train_blocks, split_qkv=split_qkv, + train_t5xxl=train_t5xxl, varbose=True, ) @@ -370,9 +376,10 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh else: weights_sd = torch.load(file, map_location="cpu") - # get dim/alpha mapping + # get dim/alpha mapping, and train t5xxl modules_dim = {} modules_alpha = {} + train_t5xxl = None for key, value in weights_sd.items(): if "." not in key: continue @@ -385,6 +392,12 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh modules_dim[lora_name] = dim # logger.info(lora_name, value.size(), dim) + if train_t5xxl is None: + train_t5xxl = "lora_te3" in lora_name + + if train_t5xxl is None: + train_t5xxl = False + # # split qkv # double_qkv_rank = None # single_qkv_rank = None @@ -413,6 +426,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh modules_alpha=modules_alpha, module_class=module_class, split_qkv=split_qkv, + train_t5xxl=train_t5xxl, ) return network, weights_sd @@ -421,10 +435,10 @@ class LoRANetwork(torch.nn.Module): # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"] FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"] LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" - LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible def __init__( self, @@ -443,6 +457,7 @@ class LoRANetwork(torch.nn.Module): modules_alpha: Optional[Dict[str, int]] = None, train_blocks: Optional[str] = None, split_qkv: bool = False, + train_t5xxl: bool = False, varbose: Optional[bool] = False, ) -> None: super().__init__() @@ -457,6 +472,7 @@ class LoRANetwork(torch.nn.Module): self.module_dropout = module_dropout self.train_blocks = train_blocks if train_blocks is not None else "all" self.split_qkv = split_qkv + self.train_t5xxl = train_t5xxl self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -469,12 +485,16 @@ class LoRANetwork(torch.nn.Module): logger.info( f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" ) - if self.conv_lora_dim is not None: - logger.info( - f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" - ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) if self.split_qkv: logger.info(f"split qkv for LoRA") + if self.train_blocks is not None: + logger.info(f"train {self.train_blocks} blocks only") + if train_t5xxl: + logger.info(f"train T5XXL as well") # create module instances def create_modules( @@ -550,12 +570,15 @@ class LoRANetwork(torch.nn.Module): skipped_te = [] for i, text_encoder in enumerate(text_encoders): index = i + if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False + break + logger.info(f"create LoRA for Text Encoder {index+1}:") text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.") self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # create LoRA for U-Net if self.train_blocks == "all": diff --git a/train_network.py b/train_network.py index 4204bce3..a68ccfcc 100644 --- a/train_network.py +++ b/train_network.py @@ -157,6 +157,9 @@ class NetworkTrainer: # region SD/SDXL + def post_process_network(self, args, accelerator, network, text_encoders, unet): + pass + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False @@ -237,6 +240,13 @@ class NetworkTrainer: def is_text_encoder_not_needed_for_training(self, args): return False # use for sample images + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + # set top parameter requires_grad = True for gradient checkpointing works + text_encoder.text_model.embeddings.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + # endregion def train(self, args): @@ -329,7 +339,7 @@ class NetworkTrainer: train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - self.assert_extra_args(args, train_dataset_group) + self.assert_extra_args(args, train_dataset_group) # may change some args # acceleratorを準備する logger.info("preparing accelerator") @@ -428,12 +438,15 @@ class NetworkTrainer: ) args.scale_weight_norms = False + self.post_process_network(args, accelerator, network, text_encoders, unet) + + # apply network to unet and text_encoder train_unet = not args.network_train_text_encoder_only train_text_encoder = self.is_train_text_encoder(args) network.apply_to(text_encoder, unet, train_text_encoder, train_unet) if args.network_weights is not None: - # FIXME consider alpha of weights + # FIXME consider alpha of weights: this assumes that the alpha is not changed info = network.load_weights(args.network_weights) accelerator.print(f"load network weights from {args.network_weights}: {info}") @@ -533,7 +546,7 @@ class NetworkTrainer: ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" accelerator.print("enable fp8 training for U-Net.") unet_weight_dtype = torch.float8_e4m3fn - + if not args.fp8_base_unet: accelerator.print("enable fp8 training for Text Encoder.") te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn @@ -545,17 +558,16 @@ class NetworkTrainer: unet.requires_grad_(False) unet.to(dtype=unet_weight_dtype) - for t_enc in text_encoders: + for i, t_enc in enumerate(text_encoders): t_enc.requires_grad_(False) # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 if t_enc.device.type != "cpu": t_enc.to(dtype=te_weight_dtype) - if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"): - # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) - elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"): - t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + + # nn.Embedding not support FP8 + if te_weight_dtype != weight_dtype: + self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: @@ -596,12 +608,12 @@ class NetworkTrainer: if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() - for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): + for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))): t_enc.train() # set top parameter requires_grad = True for gradient checkpointing works if frag: - t_enc.text_model.embeddings.requires_grad_(True) + self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc) else: unet.eval() @@ -1028,8 +1040,12 @@ class NetworkTrainer: # log device and dtype for each model logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") - for t_enc in text_encoders: - logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}") + for i, t_enc in enumerate(text_encoders): + params_itr = t_enc.parameters() + params_itr.__next__() # skip the first parameter + params_itr.__next__() # skip the second parameter. because CLIP first two parameters are embeddings + param_3rd = params_itr.__next__() + logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}") clean_memory_on_device(accelerator.device) @@ -1085,11 +1101,7 @@ class NetworkTrainer: text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs - if ( - len(text_encoder_conds) == 0 - or text_encoder_conds[0] is None - or train_text_encoder - ): + if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: From b7cff0a7548e5e33f735f06293ba24119fdaa585 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 4 Sep 2024 21:35:47 +0900 Subject: [PATCH 166/356] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 84065570..c0acfa1d 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ The command to install PyTorch is as follows: ### Recent Updates Sep 4, 2024: -- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. +- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. See [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) for details. - In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching. - Fixed an issue where the training CLIP-L LoRA was not used in sample image generation during LoRA training. From 56cb2fc885d818e9c4493fb2843870d7a141db1c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 4 Sep 2024 23:15:27 +0900 Subject: [PATCH 167/356] support T5XXL LoRA, reduce peak memory usage #1560 --- flux_minimal_inference.py | 73 +++++++++++++++++++++++++++++++-------- networks/lora_flux.py | 2 +- 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 56c1b198..1c194e7c 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -5,7 +5,7 @@ import datetime import math import os import random -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional import einops import numpy as np @@ -13,6 +13,7 @@ import torch from tqdm import tqdm from PIL import Image import accelerate +from transformers import CLIPTextModel from library import device_utils from library.device_utils import init_ipex, get_preferred_device @@ -125,7 +126,7 @@ def do_sample( def generate_image( model, - clip_l, + clip_l: CLIPTextModel, t5xxl, ae, prompt: str, @@ -141,12 +142,13 @@ def generate_image( # make first noise with packed shape # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2 packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16) + noise_dtype = torch.float32 if is_fp8(dtype) else dtype noise = torch.randn( 1, packed_latent_height * packed_latent_width, 16 * 2 * 2, device=device, - dtype=dtype, + dtype=noise_dtype, generator=torch.Generator(device=device).manual_seed(seed), ) @@ -166,9 +168,48 @@ def generate_image( clip_l = clip_l.to(device) t5xxl = t5xxl.to(device) with torch.no_grad(): - if is_fp8(clip_l_dtype) or is_fp8(t5xxl_dtype): - clip_l.to(clip_l_dtype) - t5xxl.to(t5xxl_dtype) + if is_fp8(clip_l_dtype): + param_itr = clip_l.parameters() + param_itr.__next__() # skip first + param_2nd = param_itr.__next__() + if param_2nd.dtype != clip_l_dtype: + logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}") + clip_l.to(clip_l_dtype) # fp8 + clip_l.text_model.embeddings.to(dtype=torch.bfloat16) + + with accelerator.autocast(): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + + if is_fp8(t5xxl_dtype): + if flux_utils.get_t5xxl_actual_dtype(t5xxl) != t5xxl_dtype or not hasattr(t5xxl, "fp8_prepared"): + logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}") + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + text_encoder.fp8_prepared = True + + t5xxl.to(t5xxl_dtype) + prepare_fp8(t5xxl.encoder, torch.bfloat16) + with accelerator.autocast(): _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask @@ -315,10 +356,10 @@ if __name__ == "__main__": t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device) t5xxl.eval() - if is_fp8(clip_l_dtype): - clip_l = accelerator.prepare(clip_l) - if is_fp8(t5xxl_dtype): - t5xxl = accelerator.prepare(t5xxl) + # if is_fp8(clip_l_dtype): + # clip_l = accelerator.prepare(clip_l) + # if is_fp8(t5xxl_dtype): + # t5xxl = accelerator.prepare(t5xxl) t5xxl_max_length = 256 if is_schnell else 512 tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length) @@ -329,14 +370,16 @@ if __name__ == "__main__": model.eval() logger.info(f"Casting model to {flux_dtype}") model.to(flux_dtype) # make sure model is dtype - if is_fp8(flux_dtype): - model = accelerator.prepare(model) + # if is_fp8(flux_dtype): + # model = accelerator.prepare(model) + # if args.offload: + # model = model.to("cpu") # AE ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device) ae.eval() - if is_fp8(ae_dtype): - ae = accelerator.prepare(ae) + # if is_fp8(ae_dtype): + # ae = accelerator.prepare(ae) # LoRA lora_models: List[lora_flux.LoRANetwork] = [] @@ -360,7 +403,7 @@ if __name__ == "__main__": lora_model.to(device) lora_models.append(lora_model) - + if not args.interactive: generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance) else: diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 295267be..ab9ccc4d 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -392,7 +392,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh modules_dim[lora_name] = dim # logger.info(lora_name, value.size(), dim) - if train_t5xxl is None: + if train_t5xxl is None or train_t5xxl is False: train_t5xxl = "lora_te3" in lora_name if train_t5xxl is None: From 90ed2dfb526168b2e77b8d367e928d8cc44b4278 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 5 Sep 2024 08:39:29 +0900 Subject: [PATCH 168/356] feat: Add support for merging CLIP-L and T5XXL LoRA models --- README.md | 22 ++++- networks/flux_merge_lora.py | 180 ++++++++++++++++++++++++++++-------- 2 files changed, 162 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index c0acfa1d..fa81f6c0 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 5, 2024: +The LoRA merge script now supports CLIP-L and T5XXL LoRA. Please specify `--clip_l` and `--t5xxl`. `--clip_l_save_to` and `--t5xxl_save_to` specify the save destination for CLIP-L and T5XXL. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. + Sep 4, 2024: - T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. See [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) for details. - In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching. @@ -276,7 +279,7 @@ CLIP-L LoRA is not supported. ### Merge LoRA to FLUX.1 checkpoint -`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__ +`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint, CLIP-L or T5XXL models. __The script is experimental.__ ``` python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu @@ -284,13 +287,24 @@ python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`. -`--loading_device` is the device to load the LoRA models. `--working_device` is the device to merge (calculate) the models. Default is `cpu` for both. Loading / working device examples are below (in the case of `--save_precision fp16` or `--save_precision bf16`): +CLIP-L and T5XXL LoRA are supported. `--clip_l` and `--clip_l_save_to` are for CLIP-L, `--t5xxl` and `--t5xxl_save_to` are for T5XXL. Sample command is below. + +``` +--clip_l clip_l.safetensors --clip_l_save_to merged_clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --t5xxl_save_to merged_t5xxl.safetensors +``` + +FLUX.1, CLIP-L, and T5XXL can be merged together or separately for memory efficiency. + +An experimental option `--mem_eff_load_save` is available. This option is for memory-efficient loading and saving. It may also speed up loading and saving. + +`--loading_device` is the device to load the LoRA models. `--working_device` is the device to merge (calculate) the models. Default is `cpu` for both. Loading / working device examples are below (in the case of `--save_precision fp16` or `--save_precision bf16`, `float32` will consume more memory): - 'cpu' / 'cpu': Uses >50GB of RAM, but works on any machine. - 'cuda' / 'cpu': Uses 24GB of VRAM, but requires 30GB of RAM. -- 'cuda' / 'cuda': Uses 30GB of VRAM, but requires 30GB of RAM, faster than 'cuda' / 'cpu'. +- 'cpu' / 'cuda': Uses 4GB of VRAM, but requires 50GB of RAM, faster than 'cpu' / 'cpu' or 'cuda' / 'cpu'. +- 'cuda' / 'cuda': Uses 30GB of VRAM, but requires 30GB of RAM, faster than 'cpu' / 'cpu' or 'cuda' / 'cpu'. -In the case of LoRA models are trained with `bf16`, we are not sure which is better, `fp16` or `bf16` for `--save_precision`. +`--save_precision` is the precision to save the merged model. In the case of LoRA models are trained with `bf16`, we are not sure which is better, `fp16` or `bf16` for `--save_precision`. The script can merge multiple LoRA models. If you want to merge multiple LoRA models, specify `--concat` option to work the merged LoRA model properly. diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py index 2e0d4c29..5e100a3b 100644 --- a/networks/flux_merge_lora.py +++ b/networks/flux_merge_lora.py @@ -2,6 +2,7 @@ import argparse import math import os import time +from typing import Any, Dict, Union import torch from safetensors import safe_open @@ -34,11 +35,11 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False): +def save_to_file(file_name, state_dict: Dict[str, Union[Any, torch.Tensor]], dtype, metadata, mem_eff_save=False): if dtype is not None: logger.info(f"converting to {dtype}...") for key in tqdm(list(state_dict.keys())): - if type(state_dict[key]) == torch.Tensor: + if type(state_dict[key]) == torch.Tensor and state_dict[key].dtype.is_floating_point: state_dict[key] = state_dict[key].to(dtype) logger.info(f"saving to: {file_name}") @@ -49,26 +50,76 @@ def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False): def merge_to_flux_model( - loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False + loading_device, + working_device, + flux_path: str, + clip_l_path: str, + t5xxl_path: str, + models, + ratios, + merge_dtype, + save_dtype, + mem_eff_load_save=False, ): # create module map without loading state_dict - logger.info(f"loading keys from FLUX.1 model: {flux_model}") lora_name_to_module_key = {} - with safe_open(flux_model, framework="pt", device=loading_device) as flux_file: - keys = list(flux_file.keys()) - for key in keys: - if key.endswith(".weight"): - module_name = ".".join(key.split(".")[:-1]) - lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") - lora_name_to_module_key[lora_name] = key + if flux_path is not None: + logger.info(f"loading keys from FLUX.1 model: {flux_path}") + with safe_open(flux_path, framework="pt", device=loading_device) as flux_file: + keys = list(flux_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") + lora_name_to_module_key[lora_name] = key + lora_name_to_clip_l_key = {} + if clip_l_path is not None: + logger.info(f"loading keys from clip_l model: {clip_l_path}") + with safe_open(clip_l_path, framework="pt", device=loading_device) as clip_l_file: + keys = list(clip_l_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP + "_" + module_name.replace(".", "_") + lora_name_to_clip_l_key[lora_name] = key + + lora_name_to_t5xxl_key = {} + if t5xxl_path is not None: + logger.info(f"loading keys from t5xxl model: {t5xxl_path}") + with safe_open(t5xxl_path, framework="pt", device=loading_device) as t5xxl_file: + keys = list(t5xxl_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5 + "_" + module_name.replace(".", "_") + lora_name_to_t5xxl_key[lora_name] = key + + flux_state_dict = {} + clip_l_state_dict = {} + t5xxl_state_dict = {} if mem_eff_load_save: - flux_state_dict = {} - with MemoryEfficientSafeOpen(flux_model) as flux_file: - for key in tqdm(flux_file.keys()): - flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + if flux_path is not None: + with MemoryEfficientSafeOpen(flux_path) as flux_file: + for key in tqdm(flux_file.keys()): + flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + + if clip_l_path is not None: + with MemoryEfficientSafeOpen(clip_l_path) as clip_l_file: + for key in tqdm(clip_l_file.keys()): + clip_l_state_dict[key] = clip_l_file.get_tensor(key).to(loading_device) + + if t5xxl_path is not None: + with MemoryEfficientSafeOpen(t5xxl_path) as t5xxl_file: + for key in tqdm(t5xxl_file.keys()): + t5xxl_state_dict[key] = t5xxl_file.get_tensor(key).to(loading_device) else: - flux_state_dict = load_file(flux_model, device=loading_device) + if flux_path is not None: + flux_state_dict = load_file(flux_path, device=loading_device) + if clip_l_path is not None: + clip_l_state_dict = load_file(clip_l_path, device=loading_device) + if t5xxl_path is not None: + t5xxl_state_dict = load_file(t5xxl_path, device=loading_device) for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") @@ -81,8 +132,20 @@ def merge_to_flux_model( up_key = key.replace("lora_down", "lora_up") alpha_key = key[: key.index("lora_down")] + "alpha" - if lora_name not in lora_name_to_module_key: - logger.warning(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.") + if lora_name in lora_name_to_module_key: + module_weight_key = lora_name_to_module_key[lora_name] + state_dict = flux_state_dict + elif lora_name in lora_name_to_clip_l_key: + module_weight_key = lora_name_to_clip_l_key[lora_name] + state_dict = clip_l_state_dict + elif lora_name in lora_name_to_t5xxl_key: + module_weight_key = lora_name_to_t5xxl_key[lora_name] + state_dict = t5xxl_state_dict + else: + logger.warning( + f"no module found for LoRA weight: {key}. Skipping..." + f"LoRAの重みに対応するモジュールが見つかりませんでした。スキップします。" + ) continue down_weight = lora_sd.pop(key) @@ -93,11 +156,7 @@ def merge_to_flux_model( scale = alpha / dim # W <- W + U * D - module_weight_key = lora_name_to_module_key[lora_name] - if module_weight_key not in flux_state_dict: - weight = flux_file.get_tensor(module_weight_key) - else: - weight = flux_state_dict[module_weight_key] + weight = state_dict[module_weight_key] weight = weight.to(working_device, merge_dtype) up_weight = up_weight.to(working_device, merge_dtype) @@ -121,7 +180,7 @@ def merge_to_flux_model( # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + ratio * conved * scale - flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype) + state_dict[module_weight_key] = weight.to(loading_device, save_dtype) del up_weight del down_weight del weight @@ -129,7 +188,7 @@ def merge_to_flux_model( if len(lora_sd) > 0: logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}") - return flux_state_dict + return flux_state_dict, clip_l_state_dict, t5xxl_state_dict def merge_to_flux_model_diffusers( @@ -508,17 +567,28 @@ def merge(args): if save_dtype is None: save_dtype = merge_dtype - dest_dir = os.path.dirname(args.save_to) + assert ( + args.save_to or args.clip_l_save_to or args.t5xxl_save_to + ), "save_to or clip_l_save_to or t5xxl_save_to must be specified / save_toまたはclip_l_save_toまたはt5xxl_save_toを指定してください" + dest_dir = os.path.dirname(args.save_to or args.clip_l_save_to or args.t5xxl_save_to) if not os.path.exists(dest_dir): logger.info(f"creating directory: {dest_dir}") os.makedirs(dest_dir) - if args.flux_model is not None: + if args.flux_model is not None or args.clip_l is not None or args.t5xxl is not None: if not args.diffusers: - state_dict = merge_to_flux_model( + assert (args.clip_l is None and args.clip_l_save_to is None) or ( + args.clip_l is not None and args.clip_l_save_to is not None + ), "clip_l_save_to must be specified if clip_l is specified / clip_lが指定されている場合はclip_l_save_toも指定してください" + assert (args.t5xxl is None and args.t5xxl_save_to is None) or ( + args.t5xxl is not None and args.t5xxl_save_to is not None + ), "t5xxl_save_to must be specified if t5xxl is specified / t5xxlが指定されている場合はt5xxl_save_toも指定してください" + flux_state_dict, clip_l_state_dict, t5xxl_state_dict = merge_to_flux_model( args.loading_device, args.working_device, args.flux_model, + args.clip_l, + args.t5xxl, args.models, args.ratios, merge_dtype, @@ -526,7 +596,10 @@ def merge(args): args.mem_eff_load_save, ) else: - state_dict = merge_to_flux_model_diffusers( + assert ( + args.clip_l is None and args.t5xxl is None + ), "clip_l and t5xxl are not supported with --diffusers / clip_l、t5xxlはDiffusersではサポートされていません" + flux_state_dict = merge_to_flux_model_diffusers( args.loading_device, args.working_device, args.flux_model, @@ -536,8 +609,10 @@ def merge(args): save_dtype, args.mem_eff_load_save, ) + clip_l_state_dict = None + t5xxl_state_dict = None - if args.no_metadata: + if args.no_metadata or (flux_state_dict is None or len(flux_state_dict) == 0): sai_metadata = None else: merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) @@ -546,15 +621,24 @@ def merge(args): None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev" ) - logger.info(f"saving FLUX model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, sai_metadata, args.mem_eff_load_save) + if flux_state_dict is not None and len(flux_state_dict) > 0: + logger.info(f"saving FLUX model to: {args.save_to}") + save_to_file(args.save_to, flux_state_dict, save_dtype, sai_metadata, args.mem_eff_load_save) + + if clip_l_state_dict is not None and len(clip_l_state_dict) > 0: + logger.info(f"saving clip_l model to: {args.clip_l_save_to}") + save_to_file(args.clip_l_save_to, clip_l_state_dict, save_dtype, None, args.mem_eff_load_save) + + if t5xxl_state_dict is not None and len(t5xxl_state_dict) > 0: + logger.info(f"saving t5xxl model to: {args.t5xxl_save_to}") + save_to_file(args.t5xxl_save_to, t5xxl_state_dict, save_dtype, None, args.mem_eff_load_save) else: - state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + flux_state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) logger.info("calculating hashes and creating metadata...") - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(flux_state_dict, metadata) metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash @@ -562,12 +646,12 @@ def merge(args): merged_from = sai_model_spec.build_merged_from(args.models) title = os.path.splitext(os.path.basename(args.save_to))[0] sai_metadata = sai_model_spec.build_metadata( - state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" + flux_state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" ) metadata.update(sai_metadata) logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, metadata) + save_to_file(args.save_to, flux_state_dict, save_dtype, metadata) def setup_parser() -> argparse.ArgumentParser: @@ -592,6 +676,18 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする", ) + parser.add_argument( + "--clip_l", + type=str, + default=None, + help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)", + ) + parser.add_argument( + "--t5xxl", + type=str, + default=None, + help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)", + ) parser.add_argument( "--mem_eff_load_save", action="store_true", @@ -617,6 +713,18 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="destination file name: safetensors file / 保存先のファイル名、safetensorsファイル", ) + parser.add_argument( + "--clip_l_save_to", + type=str, + default=None, + help="destination file name for clip_l: safetensors file / clip_lの保存先のファイル名、safetensorsファイル", + ) + parser.add_argument( + "--t5xxl_save_to", + type=str, + default=None, + help="destination file name for t5xxl: safetensors file / t5xxlの保存先のファイル名、safetensorsファイル", + ) parser.add_argument( "--models", type=str, From d9129522a6effea7077f18cdea0ee733a5ac7cb0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 5 Sep 2024 12:20:07 +0900 Subject: [PATCH 169/356] set dtype before calling ae closes #1562 --- flux_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux_train.py b/flux_train.py index 32a36f03..0293b7be 100644 --- a/flux_train.py +++ b/flux_train.py @@ -651,7 +651,7 @@ def train(args): else: with torch.no_grad(): # encode images to latents. images are [-1, 1] - latents = ae.encode(batch["images"]) + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): From 2889108d858880589d362e06e98eeadf4682476a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 5 Sep 2024 20:58:33 +0900 Subject: [PATCH 170/356] feat: Add --cpu_offload_checkpointing option to LoRA training --- README.md | 7 +++++++ flux_train.py | 2 +- flux_train_network.py | 5 +++++ train_network.py | 12 +++++++++++- 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index fa81f6c0..e8a12089 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,12 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 5, 2024 (update 1): + +Added `--cpu_offload_checkpointing` option to LoRA training script. Offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`. + Sep 5, 2024: + The LoRA merge script now supports CLIP-L and T5XXL LoRA. Please specify `--clip_l` and `--t5xxl`. `--clip_l_save_to` and `--t5xxl_save_to` specify the save destination for CLIP-L and T5XXL. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. Sep 4, 2024: @@ -72,6 +77,8 @@ The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_ --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 ``` +`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`. + We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. The trained LoRA model can be used with ComfyUI. diff --git a/flux_train.py b/flux_train.py index 0293b7be..0edc83a9 100644 --- a/flux_train.py +++ b/flux_train.py @@ -261,7 +261,7 @@ def train(args): ) if args.gradient_checkpointing: - flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing) + flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) flux.requires_grad_(True) diff --git a/flux_train_network.py b/flux_train_network.py index 2fc0f323..a6e57eed 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -50,6 +50,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + assert not args.split_mode or not args.cpu_offload_checkpointing, ( + "split_mode and cpu_offload_checkpointing cannot be used together" + " / split_modeとcpu_offload_checkpointingは同時に使用できません" + ) + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this def get_flux_model_name(self, args): diff --git a/train_network.py b/train_network.py index a68ccfcc..ad97491d 100644 --- a/train_network.py +++ b/train_network.py @@ -451,7 +451,11 @@ class NetworkTrainer: accelerator.print(f"load network weights from {args.network_weights}: {info}") if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() + if args.cpu_offload_checkpointing: + unet.enable_gradient_checkpointing(cpu_offload=True) + else: + unet.enable_gradient_checkpointing() + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): if flag: if t_enc.supports_gradient_checkpointing: @@ -1281,6 +1285,12 @@ def setup_parser() -> argparse.ArgumentParser: config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported" + " / 勾配チェックポイント時にテンソルをCPUにオフロードする(U-NetまたはDiTのみ、サポートされている場合)", + ) parser.add_argument( "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない" ) From 0005867ba509d2e1a5674b267e8286b561c0ed71 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 7 Sep 2024 10:45:18 +0900 Subject: [PATCH 171/356] update README, format code --- README.md | 5 +++++ library/train_util.py | 4 ++-- library/utils.py | 4 +++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 81a54937..16ab80e7 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress +- When enlarging images in the script (when the size of the training image is small and bucket_no_upscale is not specified), it has been changed to use Pillow's resize and LANCZOS interpolation instead of OpenCV2's resize and Lanczos4 interpolation. The quality of the image enlargement may be slightly improved. PR [#1426](https://github.com/kohya-ss/sd-scripts/pull/1426) Thanks to sdbds! + +- Sample image generation during training now works on non-CUDA devices. PR [#1433](https://github.com/kohya-ss/sd-scripts/pull/1433) Thanks to millie-v! + - `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened! + - Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available. diff --git a/library/train_util.py b/library/train_util.py index 102d39ed..1441e74f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2094,7 +2094,7 @@ class ControlNetDataset(BaseDataset): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img=pil_resize(cond_img,(int(target_size_hw[1]), int(target_size_hw[0]))) + cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0]))) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -2432,7 +2432,7 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: return train_dataset_group -def load_image(image_path, alpha=False): +def load_image(image_path, alpha=False): try: with Image.open(image_path) as image: if alpha: diff --git a/library/utils.py b/library/utils.py index a219f6cb..5b7e657b 100644 --- a/library/utils.py +++ b/library/utils.py @@ -11,6 +11,7 @@ import cv2 from PIL import Image import numpy as np + def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() @@ -80,8 +81,8 @@ def setup_logging(args=None, log_level=None, reset=False): logger = logging.getLogger(__name__) logger.info(msg_init) -def pil_resize(image, size, interpolation=Image.LANCZOS): +def pil_resize(image, size, interpolation=Image.LANCZOS): pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) # use Pillow resize @@ -92,6 +93,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS): return resized_cv2 + # TODO make inf_utils.py From d29af146b8d4c4d028f8752657bd1349c8cd3509 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 9 Sep 2024 23:01:15 +0900 Subject: [PATCH 172/356] add negative prompt for flux inference script --- README.md | 3 + flux_minimal_inference.py | 283 +++++++++++++++++++++++++++----------- 2 files changed, 203 insertions(+), 83 deletions(-) diff --git a/README.md b/README.md index 2f010f49..126516f9 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 9, 2024: +Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used. + Sep 5, 2024 (update 1): Added `--cpu_offload_checkpointing` option to LoRA training script. Offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`. diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 1c194e7c..de607c52 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -71,22 +71,57 @@ def denoise( timesteps: list[float], guidance: float = 4.0, t5_attn_mask: Optional[torch.Tensor] = None, + neg_txt: Optional[torch.Tensor] = None, + neg_vec: Optional[torch.Tensor] = None, + neg_t5_attn_mask: Optional[torch.Tensor] = None, + cfg_scale: Optional[float] = None, ): # this is ignored for schnell + logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}") guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + + # prepare classifier free guidance + if neg_txt is not None and neg_vec is not None: + b_img_ids = torch.cat([img_ids, img_ids], dim=0) + b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0) + b_txt = torch.cat([neg_txt, txt], dim=0) + b_vec = torch.cat([neg_vec, vec], dim=0) + if t5_attn_mask is not None and neg_t5_attn_mask is not None: + b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0) + else: + b_t5_attn_mask = None + else: + b_img_ids = img_ids + b_txt_ids = txt_ids + b_txt = txt + b_vec = vec + b_t5_attn_mask = t5_attn_mask + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): - t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device) + + # classifier free guidance + if neg_txt is not None and neg_vec is not None: + b_img = torch.cat([img, img], dim=0) + else: + b_img = img + pred = model( - img=img, - img_ids=img_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, + img=b_img, + img_ids=b_img_ids, + txt=b_txt, + txt_ids=b_txt_ids, + y=b_vec, timesteps=t_vec, guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, + txt_attention_mask=b_t5_attn_mask, ) + # classifier free guidance + if neg_txt is not None and neg_vec is not None: + pred_uncond, pred = torch.chunk(pred, 2, dim=0) + pred = pred_uncond + cfg_scale * (pred - pred_uncond) + img = img + (t_prev - t_curr) * pred return img @@ -106,19 +141,48 @@ def do_sample( is_schnell: bool, device: torch.device, flux_dtype: torch.dtype, + neg_l_pooled: Optional[torch.Tensor] = None, + neg_t5_out: Optional[torch.Tensor] = None, + neg_t5_attn_mask: Optional[torch.Tensor] = None, + cfg_scale: Optional[float] = None, ): + logger.info(f"num_steps: {num_steps}") timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell) # denoise initial noise if accelerator: with accelerator.autocast(), torch.no_grad(): x = denoise( - model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + model, + img, + img_ids, + t5_out, + txt_ids, + l_pooled, + timesteps, + guidance, + t5_attn_mask, + neg_t5_out, + neg_l_pooled, + neg_t5_attn_mask, + cfg_scale, ) else: with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad(): x = denoise( - model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + model, + img, + img_ids, + t5_out, + txt_ids, + l_pooled, + timesteps, + guidance, + t5_attn_mask, + neg_t5_out, + neg_l_pooled, + neg_t5_attn_mask, + cfg_scale, ) return x @@ -135,6 +199,8 @@ def generate_image( image_height: int, steps: Optional[int], guidance: float, + negative_prompt: Optional[str], + cfg_scale: float, ): seed = seed if seed is not None else random.randint(0, 2**32 - 1) logger.info(f"Seed: {seed}") @@ -162,65 +228,73 @@ def generate_image( # txt2img only needs img_ids img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width) + # prepare fp8 models + if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared): + logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}") + clip_l.to(clip_l_dtype) # fp8 + clip_l.text_model.embeddings.to(dtype=torch.bfloat16) + clip_l.fp8_prepared = True + + if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared): + logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}") + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + t5xxl.to(t5xxl_dtype) + prepare_fp8(t5xxl.encoder, torch.bfloat16) + t5xxl.fp8_prepared = True + # prepare embeddings logger.info("Encoding prompts...") - tokens_and_masks = tokenize_strategy.tokenize(prompt) clip_l = clip_l.to(device) t5xxl = t5xxl.to(device) - with torch.no_grad(): - if is_fp8(clip_l_dtype): - param_itr = clip_l.parameters() - param_itr.__next__() # skip first - param_2nd = param_itr.__next__() - if param_2nd.dtype != clip_l_dtype: - logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}") - clip_l.to(clip_l_dtype) # fp8 - clip_l.text_model.embeddings.to(dtype=torch.bfloat16) - with accelerator.autocast(): - l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + def encode(prpt: str): + tokens_and_masks = tokenize_strategy.tokenize(prpt) + with torch.no_grad(): + if is_fp8(clip_l_dtype): + with accelerator.autocast(): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + else: + with torch.autocast(device_type=device.type, dtype=clip_l_dtype): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) - if is_fp8(t5xxl_dtype): - if flux_utils.get_t5xxl_actual_dtype(t5xxl) != t5xxl_dtype or not hasattr(t5xxl, "fp8_prepared"): - logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}") + if is_fp8(t5xxl_dtype): + with accelerator.autocast(): + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + else: + with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + return l_pooled, t5_out, txt_ids, t5_attn_mask - def prepare_fp8(text_encoder, target_dtype): - def forward_hook(module): - def forward(hidden_states): - hidden_gelu = module.act(module.wi_0(hidden_states)) - hidden_linear = module.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = module.dropout(hidden_states) - - hidden_states = module.wo(hidden_states) - return hidden_states - - return forward - - for module in text_encoder.modules(): - if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: - # print("set", module.__class__.__name__, "to", target_dtype) - module.to(target_dtype) - if module.__class__.__name__ in ["T5DenseGatedActDense"]: - # print("set", module.__class__.__name__, "hooks") - module.forward = forward_hook(module) - - text_encoder.fp8_prepared = True - - t5xxl.to(t5xxl_dtype) - prepare_fp8(t5xxl.encoder, torch.bfloat16) - - with accelerator.autocast(): - _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask - ) - else: - with torch.autocast(device_type=device.type, dtype=clip_l_dtype): - l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) - with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): - _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( - tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask - ) + l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt) + if negative_prompt: + neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt) + else: + neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None # NaN check if torch.isnan(l_pooled).any(): @@ -244,7 +318,23 @@ def generate_image( t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None x = do_sample( - accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, t5_attn_mask, is_schnell, device, flux_dtype + accelerator, + model, + noise, + img_ids, + l_pooled, + t5_out, + txt_ids, + steps, + guidance, + t5_attn_mask, + is_schnell, + device, + flux_dtype, + neg_l_pooled, + neg_t5_out, + neg_t5_attn_mask, + cfg_scale, ) if args.offload: model = model.cpu() @@ -307,6 +397,8 @@ if __name__ == "__main__": parser.add_argument("--seed", type=int, default=None) parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev") parser.add_argument("--guidance", type=float, default=3.5) + parser.add_argument("--negative_prompt", type=str, default=None) + parser.add_argument("--cfg_scale", type=float, default=1.0) parser.add_argument("--offload", action="store_true", help="Offload to CPU") parser.add_argument( "--lora_weights", @@ -403,19 +495,34 @@ if __name__ == "__main__": lora_model.to(device) lora_models.append(lora_model) - + if not args.interactive: - generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance) + generate_image( + model, + clip_l, + t5xxl, + ae, + args.prompt, + args.seed, + args.width, + args.height, + args.steps, + args.guidance, + args.negative_prompt, + args.cfg_scale, + ) else: # loop for interactive width = target_width height = target_height steps = None guidance = args.guidance + cfg_scale = args.cfg_scale while True: print( "Enter prompt (empty to exit). Options: --w --h --s --d --g --m " + " --n , `-` for empty negative prompt --c " ) prompt = input() if prompt == "": @@ -425,26 +532,36 @@ if __name__ == "__main__": options = prompt.split("--") prompt = options[0].strip() seed = None + negative_prompt = None for opt in options[1:]: - opt = opt.strip() - if opt.startswith("w"): - width = int(opt[1:].strip()) - elif opt.startswith("h"): - height = int(opt[1:].strip()) - elif opt.startswith("s"): - steps = int(opt[1:].strip()) - elif opt.startswith("d"): - seed = int(opt[1:].strip()) - elif opt.startswith("g"): - guidance = float(opt[1:].strip()) - elif opt.startswith("m"): - mutipliers = opt[1:].strip().split(",") - if len(mutipliers) != len(lora_models): - logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") - continue - for i, lora_model in enumerate(lora_models): - lora_model.set_multiplier(float(mutipliers[i])) + try: + opt = opt.strip() + if opt.startswith("w"): + width = int(opt[1:].strip()) + elif opt.startswith("h"): + height = int(opt[1:].strip()) + elif opt.startswith("s"): + steps = int(opt[1:].strip()) + elif opt.startswith("d"): + seed = int(opt[1:].strip()) + elif opt.startswith("g"): + guidance = float(opt[1:].strip()) + elif opt.startswith("m"): + mutipliers = opt[1:].strip().split(",") + if len(mutipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(mutipliers[i])) + elif opt.startswith("n"): + negative_prompt = opt[1:].strip() + if negative_prompt == "-": + negative_prompt = "" + elif opt.startswith("c"): + cfg_scale = float(opt[1:].strip()) + except ValueError as e: + logger.error(f"Invalid option: {opt}, {e}") - generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance) + generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale) logger.info("Done!") From d10ff62a78b15d0bb55f443cc2849c460300131b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 10 Sep 2024 20:32:09 +0900 Subject: [PATCH 173/356] support individual LR for CLIP-L/T5XXL --- README.md | 4 +++ networks/lora_flux.py | 71 +++++++++++++++---------------------------- train_network.py | 32 ++++++++++++------- 3 files changed, 49 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index 126516f9..b5799dd6 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 10, 2024: +In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. + Sep 9, 2024: Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used. @@ -142,6 +145,7 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times - Remove `--network_train_unet_only` from your command. - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time. - T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. + - The learning rates for CLIP-L and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL. - The trained LoRA can be used with ComfyUI. - Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet. diff --git a/networks/lora_flux.py b/networks/lora_flux.py index ab9ccc4d..d540c221 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -786,28 +786,23 @@ class LoRANetwork(torch.nn.Module): logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") - # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): - # TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?) - # if ( - # self.loraplus_lr_ratio is not None - # or self.loraplus_text_encoder_lr_ratio is not None - # or self.loraplus_unet_lr_ratio is not None - # ): - # assert ( - # optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower() - # ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません" + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of two elements + if text_encoder_lr is None or len(text_encoder_lr) == 0: + text_encoder_lr = [default_lr, default_lr] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] self.requires_grad_(True) all_params = [] lr_descriptions = [] - def assemble_params(loras, lr, ratio): + def assemble_params(loras, lr, loraplus_ratio): param_groups = {"lora": {}, "plus": {}} for lora in loras: for name, param in lora.named_parameters(): - if ratio is not None and "lora_up" in name: + if loraplus_ratio is not None and "lora_up" in name: param_groups["plus"][f"{lora.lora_name}.{name}"] = param else: param_groups["lora"][f"{lora.lora_name}.{name}"] = param @@ -822,7 +817,7 @@ class LoRANetwork(torch.nn.Module): if lr is not None: if key == "plus": - param_data["lr"] = lr * ratio + param_data["lr"] = lr * loraplus_ratio else: param_data["lr"] = lr @@ -836,41 +831,23 @@ class LoRANetwork(torch.nn.Module): return params, descriptions if self.text_encoder_loras: - params, descriptions = assemble_params( - self.text_encoder_loras, - text_encoder_lr if text_encoder_lr is not None else default_lr, - self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio, - ) - all_params.extend(params) - lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions]) + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP)] + te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)] + if len(te1_loras) > 0: + logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te3_loras) > 0: + logger.info(f"Text Encoder 2 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[1]}") + params, descriptions = assemble_params(te3_loras, text_encoder_lr[1], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 2 " + (" " + d if d else "") for d in descriptions]) if self.unet_loras: - # if self.block_lr: - # is_sdxl = False - # for lora in self.unet_loras: - # if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name: - # is_sdxl = True - # break - - # # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 - # block_idx_to_lora = {} - # for lora in self.unet_loras: - # idx = get_block_index(lora.lora_name, is_sdxl) - # if idx not in block_idx_to_lora: - # block_idx_to_lora[idx] = [] - # block_idx_to_lora[idx].append(lora) - - # # blockごとにパラメータを設定する - # for idx, block_loras in block_idx_to_lora.items(): - # params, descriptions = assemble_params( - # block_loras, - # (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx), - # self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, - # ) - # all_params.extend(params) - # lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions]) - - # else: params, descriptions = assemble_params( self.unet_loras, unet_lr if unet_lr is not None else default_lr, diff --git a/train_network.py b/train_network.py index ad97491d..e45db052 100644 --- a/train_network.py +++ b/train_network.py @@ -466,9 +466,17 @@ class NetworkTrainer: # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - # 後方互換性を確保するよ + # make backward compatibility for text_encoder_lr + support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs") + if support_multiple_lrs: + text_encoder_lr = args.text_encoder_lr + else: + text_encoder_lr = None if args.text_encoder_lr is None or len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] try: - results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + if support_multiple_lrs: + results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate) + else: + results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate) if type(results) is tuple: trainable_params = results[0] lr_descriptions = results[1] @@ -476,11 +484,7 @@ class NetworkTrainer: trainable_params = results lr_descriptions = None except TypeError as e: - # logger.warning(f"{e}") - # accelerator.print( - # "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" - # ) - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + trainable_params = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr) lr_descriptions = None # if len(trainable_params) == 0: @@ -713,7 +717,7 @@ class NetworkTrainer: "ss_training_started_at": training_started_at, # unix timestamp "ss_output_name": args.output_name, "ss_learning_rate": args.learning_rate, - "ss_text_encoder_lr": args.text_encoder_lr, + "ss_text_encoder_lr": text_encoder_lr, "ss_unet_lr": args.unet_lr, "ss_num_train_images": train_dataset_group.num_train_images, "ss_num_reg_images": train_dataset_group.num_reg_images, @@ -760,8 +764,8 @@ class NetworkTrainer: "ss_loss_type": args.loss_type, "ss_huber_schedule": args.huber_schedule, "ss_huber_c": args.huber_c, - "ss_fp8_base": args.fp8_base, - "ss_fp8_base_unet": args.fp8_base_unet, + "ss_fp8_base": bool(args.fp8_base), + "ss_fp8_base_unet": bool(args.fp8_base_unet), } self.update_metadata(metadata, args) # architecture specific metadata @@ -1303,7 +1307,13 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") - parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") + parser.add_argument( + "--text_encoder_lr", + type=float, + default=None, + nargs="*", + help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能", + ) parser.add_argument( "--fp8_base_unet", action="store_true", From 65b8a064f6bb9a403374d4b08f4003037df42f8d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 10 Sep 2024 21:20:38 +0900 Subject: [PATCH 174/356] update README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b5799dd6..caea59b7 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ The command to install PyTorch is as follows: ### Recent Updates Sep 10, 2024: -In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. +In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. The first value is the learning rate for CLIP-L, and the second value is for T5XXL. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. Sep 9, 2024: Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used. @@ -145,7 +145,7 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times - Remove `--network_train_unet_only` from your command. - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time. - T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. - - The learning rates for CLIP-L and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL. + - The learning rates for CLIP-L and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5`. The first value is the learning rate for CLIP-L, and the second value is for T5XXL. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL. - The trained LoRA can be used with ComfyUI. - Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet. From fd68703f3795b3e9c75409ac5452807d056b928f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Wed, 11 Sep 2024 20:25:45 +0800 Subject: [PATCH 175/356] Add New lr scheduler (#1393) * add new lr scheduler * fix bugs and use num_cycles / 2 * Update requirements.txt * add num_cycles for min lr * keep PIECEWISE_CONSTANT * allow use float with warmup or decay ratio. * Update train_util.py --- library/train_util.py | 80 ++++++++++++++++++++++++++++++++++++++----- requirements.txt | 6 ++-- 2 files changed, 75 insertions(+), 11 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index c7b73ee3..340f6d64 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -42,7 +42,8 @@ from torch.optim import Optimizer from torchvision import transforms from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection import transformers -from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION +from diffusers.optimization import SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION +from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from diffusers import ( StableDiffusionPipeline, DDPMScheduler, @@ -2972,6 +2973,20 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): def add_optimizer_arguments(parser: argparse.ArgumentParser): + def int_or_float(value): + if value.endswith('%'): + try: + return float(value[:-1]) / 100.0 + except ValueError: + raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage") + try: + float_value = float(value) + if float_value >= 1: + return int(value) + return float(value) + except ValueError: + raise argparse.ArgumentTypeError(f"'{value}' is not an int or float") + parser.add_argument( "--optimizer_type", type=str, @@ -3024,9 +3039,15 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): ) parser.add_argument( "--lr_warmup_steps", - type=int, + type=int_or_float, default=0, - help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)", + help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)", + ) + parser.add_argument( + "--lr_decay_steps", + type=int_or_float, + default=0, + help="Int number of steps for the decay in the lr scheduler (default is 0) or float with ratio of train steps", ) parser.add_argument( "--lr_scheduler_num_cycles", @@ -3046,6 +3067,18 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL" + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効", ) + parser.add_argument( + "--lr_scheduler_timescale", + type=int, + default=None, + help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`", + ) + parser.add_argument( + "--lr_scheduler_min_lr_ratio", + type=float, + default=None, + help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler", + ) def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): @@ -4293,10 +4326,14 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): Unified API to get any scheduler from its name. """ name = args.lr_scheduler - num_warmup_steps: Optional[int] = args.lr_warmup_steps num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps + num_warmup_steps: Optional[int] = int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps + num_decay_steps: Optional[int] = int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps + num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps num_cycles = args.lr_scheduler_num_cycles power = args.lr_scheduler_power + timescale = args.lr_scheduler_timescale + min_lr_ratio = args.lr_scheduler_min_lr_ratio lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0: @@ -4332,13 +4369,13 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): # logger.info(f"adafactor scheduler init lr {initial_lr}") return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) - name = SchedulerType(name) - schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + name = SchedulerType(name) or DiffusersSchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] or DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name] if name == SchedulerType.CONSTANT: return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs)) - if name == SchedulerType.PIECEWISE_CONSTANT: + if name == DiffusersSchedulerType.PIECEWISE_CONSTANT: return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs # All other schedulers require `num_warmup_steps` @@ -4348,6 +4385,9 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): if name == SchedulerType.CONSTANT_WITH_WARMUP: return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs) + if name == SchedulerType.INVERSE_SQRT: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs) + # All other schedulers require `num_training_steps` if num_training_steps is None: raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") @@ -4366,7 +4406,31 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs ) - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs) + if name == SchedulerType.COSINE_WITH_MIN_LR: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles / 2, + min_lr_rate=min_lr_ratio, + **lr_scheduler_kwargs, + ) + + # All other schedulers require `num_decay_steps` + if num_decay_steps is None: + raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") + if name == SchedulerType.WARMUP_STABLE_DECAY: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_stable_steps=num_stable_steps, + num_decay_steps=num_decay_steps, + num_cycles=num_cycles / 2, + min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0, + **lr_scheduler_kwargs, + ) + + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_decay_steps=num_decay_steps, **lr_scheduler_kwargs) def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): diff --git a/requirements.txt b/requirements.txt index 977c5cd9..d2a2fbb8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -accelerate==0.25.0 -transformers==4.36.2 +accelerate==0.30.0 +transformers==4.41.2 diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 @@ -16,7 +16,7 @@ altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 -huggingface-hub==0.20.1 +huggingface-hub==0.23.3 # for Image utils imagesize==1.4.1 # for BLIP captioning From 6dbfd47a59cdb91be2077e1d0dec0f94698348dd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 11 Sep 2024 21:44:36 +0900 Subject: [PATCH 176/356] Fix to work PIECEWISE_CONSTANT, update requirement.txt and README #1393 --- README.md | 9 ++++++ library/train_util.py | 66 ++++++++++++++++++++++++++++--------------- requirements.txt | 4 +-- 3 files changed, 54 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 16ab80e7..011141bf 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,15 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress +- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries. + - transformers, accelerate and huggingface_hub are updated. + - If you encounter any issues, please report them. + +- en: The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds! + - See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler. + - `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc. + +https://github.com/kohya-ss/sd-scripts/pull/1393 - When enlarging images in the script (when the size of the training image is small and bucket_no_upscale is not specified), it has been changed to use Pillow's resize and LANCZOS interpolation instead of OpenCV2's resize and Lanczos4 interpolation. The quality of the image enlargement may be slightly improved. PR [#1426](https://github.com/kohya-ss/sd-scripts/pull/1426) Thanks to sdbds! - Sample image generation during training now works on non-CUDA devices. PR [#1433](https://github.com/kohya-ss/sd-scripts/pull/1433) Thanks to millie-v! diff --git a/library/train_util.py b/library/train_util.py index 340f6d64..e65760ba 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -42,7 +42,10 @@ from torch.optim import Optimizer from torchvision import transforms from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection import transformers -from diffusers.optimization import SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION +from diffusers.optimization import ( + SchedulerType as DiffusersSchedulerType, + TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION, +) from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from diffusers import ( StableDiffusionPipeline, @@ -2974,7 +2977,7 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): def add_optimizer_arguments(parser: argparse.ArgumentParser): def int_or_float(value): - if value.endswith('%'): + if value.endswith("%"): try: return float(value[:-1]) / 100.0 except ValueError: @@ -3041,13 +3044,15 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): "--lr_warmup_steps", type=int_or_float, default=0, - help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)", + help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps" + " / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)", ) parser.add_argument( "--lr_decay_steps", type=int_or_float, default=0, - help="Int number of steps for the decay in the lr scheduler (default is 0) or float with ratio of train steps", + help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps" + " / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)", ) parser.add_argument( "--lr_scheduler_num_cycles", @@ -3071,13 +3076,16 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): "--lr_scheduler_timescale", type=int, default=None, - help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`", + help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`" + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", + , ) parser.add_argument( "--lr_scheduler_min_lr_ratio", type=float, default=None, - help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler", + help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler" + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", ) @@ -4327,8 +4335,12 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): """ name = args.lr_scheduler num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps - num_warmup_steps: Optional[int] = int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps - num_decay_steps: Optional[int] = int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps + num_warmup_steps: Optional[int] = ( + int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps + ) + num_decay_steps: Optional[int] = ( + int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps + ) num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps num_cycles = args.lr_scheduler_num_cycles power = args.lr_scheduler_power @@ -4369,15 +4381,17 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): # logger.info(f"adafactor scheduler init lr {initial_lr}") return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) - name = SchedulerType(name) or DiffusersSchedulerType(name) - schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] or DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name] + if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value: + name = DiffusersSchedulerType(name) + schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name] + return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs + + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] if name == SchedulerType.CONSTANT: return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs)) - if name == DiffusersSchedulerType.PIECEWISE_CONSTANT: - return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs - # All other schedulers require `num_warmup_steps` if num_warmup_steps is None: raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") @@ -4408,11 +4422,11 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): if name == SchedulerType.COSINE_WITH_MIN_LR: return schedule_func( - optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps, + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, num_cycles=num_cycles / 2, - min_lr_rate=min_lr_ratio, + min_lr_rate=min_lr_ratio, **lr_scheduler_kwargs, ) @@ -4421,16 +4435,22 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") if name == SchedulerType.WARMUP_STABLE_DECAY: return schedule_func( - optimizer, - num_warmup_steps=num_warmup_steps, - num_stable_steps=num_stable_steps, - num_decay_steps=num_decay_steps, - num_cycles=num_cycles / 2, + optimizer, + num_warmup_steps=num_warmup_steps, + num_stable_steps=num_stable_steps, + num_decay_steps=num_decay_steps, + num_cycles=num_cycles / 2, min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0, **lr_scheduler_kwargs, ) - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_decay_steps=num_decay_steps, **lr_scheduler_kwargs) + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_decay_steps=num_decay_steps, + **lr_scheduler_kwargs, + ) def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): diff --git a/requirements.txt b/requirements.txt index d2a2fbb8..15e6e58f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ accelerate==0.30.0 -transformers==4.41.2 +transformers==4.44.0 diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 @@ -16,7 +16,7 @@ altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 -huggingface-hub==0.23.3 +huggingface-hub==0.24.5 # for Image utils imagesize==1.4.1 # for BLIP captioning From 8311e88225fef377591e5be19eb1f50fe7a2941f Mon Sep 17 00:00:00 2001 From: cocktailpeanut Date: Wed, 11 Sep 2024 09:02:29 -0400 Subject: [PATCH 177/356] typo fix --- library/train_util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index c38864fe..f682dcbf 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3355,15 +3355,14 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): type=int, default=None, help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`" - " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", - , + + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", ) parser.add_argument( "--lr_scheduler_min_lr_ratio", type=float, default=None, help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler" - " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", + + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", ) From c7c666b1829a7c1f3435558efa425b08b50fab41 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 11 Sep 2024 22:12:31 +0900 Subject: [PATCH 178/356] fix typo --- library/train_util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index e65760ba..a46d9487 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3077,15 +3077,14 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): type=int, default=None, help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`" - " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", - , + + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", ) parser.add_argument( "--lr_scheduler_min_lr_ratio", type=float, default=None, help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler" - " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", + + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", ) From a823fd9fb8d219b5b4c57df12eed41ae34fdf843 Mon Sep 17 00:00:00 2001 From: Plat <60182057+p1atdev@users.noreply.github.com> Date: Wed, 11 Sep 2024 22:21:16 +0900 Subject: [PATCH 179/356] Improve wandb logging (#1576) * fix: wrong training steps were recorded to wandb, and no log was sent when logging_dir was not specified * fix: checking of whether wandb is enabled * feat: log images to wandb with their positive prompt as captions * feat: logging sample images' caption for sd3 and flux * fix: import wandb before use --- fine_tune.py | 7 +++++-- flux_train.py | 7 +++++-- library/flux_train_utils.py | 20 +++++++++++--------- library/sd3_train_utils.py | 20 +++++++++++--------- library/train_util.py | 20 +++++++++++--------- sd3_train.py | 7 +++++-- sdxl_train.py | 7 +++++-- sdxl_train_control_net_lllite.py | 4 ++-- sdxl_train_control_net_lllite_old.py | 4 ++-- train_controlnet.py | 7 +++++-- train_db.py | 7 +++++-- train_network.py | 7 +++++-- train_textual_inversion.py | 8 ++++++-- train_textual_inversion_XTI.py | 4 ++-- 14 files changed, 80 insertions(+), 49 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index c9102f6c..fb6b3ed6 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -337,6 +337,9 @@ def train(args): train_util.sample_images( accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -456,7 +459,7 @@ def train(args): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) accelerator.log(logs, step=global_step) @@ -469,7 +472,7 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/flux_train.py b/flux_train.py index 0edc83a9..33481df8 100644 --- a/flux_train.py +++ b/flux_train.py @@ -629,6 +629,9 @@ def train(args): # For --sample_at_first flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() epoch = 0 # avoid error when max_train_steps is 0 @@ -777,7 +780,7 @@ def train(args): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) @@ -791,7 +794,7 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 0b5d4d90..f77d4b58 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -254,17 +254,19 @@ def sample_image_inference( img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" image.save(os.path.join(save_dir, img_filename)) - # wandb有効時のみログを送信 - try: + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + import wandb + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log( + {f"sample_{i}": wandb.Image( + image, + caption=prompt # positive prompt as a caption + )}, + commit=False + ) def time_shift(mu: float, sigma: float, t: torch.Tensor): diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index da072950..e819d440 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -604,17 +604,19 @@ def sample_image_inference( img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" image.save(os.path.join(save_dir, img_filename)) - # wandb有効時のみログを送信 - try: + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + import wandb + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log( + {f"sample_{i}": wandb.Image( + image, + caption=prompt # positive prompt as a caption + )}, + commit=False + ) # region Diffusers diff --git a/library/train_util.py b/library/train_util.py index f682dcbf..742d057e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5832,17 +5832,19 @@ def sample_image_inference( img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" image.save(os.path.join(save_dir, img_filename)) - # wandb有効時のみログを送信 - try: + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + import wandb + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log( + {f"sample_{i}": wandb.Image( + image, + caption=prompt # positive prompt as a caption + )}, + commit=False + ) # endregion diff --git a/sd3_train.py b/sd3_train.py index 87011b21..5120105f 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -682,6 +682,9 @@ def train(args): # For --sample_at_first sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # following function will be moved to sd3_train_utils @@ -901,7 +904,7 @@ def train(args): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_mmdit) @@ -915,7 +918,7 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/sdxl_train.py b/sdxl_train.py index b2c62dd1..7291ddd2 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -617,6 +617,9 @@ def train(args): sdxl_train_util.sample_images( accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, [text_encoder1, text_encoder2], unet ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -797,7 +800,7 @@ def train(args): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} if block_lrs is None: train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet) @@ -814,7 +817,7 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 0eaec29b..9d1cfc63 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -541,14 +541,14 @@ def train(args): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 292a0463..6fa1d609 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -480,14 +480,14 @@ def train(args): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/train_controlnet.py b/train_controlnet.py index c9ac6c5a..57f0d263 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -409,6 +409,9 @@ def train(args): train_util.sample_images( accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # training loop for epoch in range(num_train_epochs): @@ -542,14 +545,14 @@ def train(args): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/train_db.py b/train_db.py index 7caee664..d42afd89 100644 --- a/train_db.py +++ b/train_db.py @@ -315,6 +315,9 @@ def train(args): train_util.sample_images( accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -445,7 +448,7 @@ def train(args): ) current_loss = loss.detach().item() - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) accelerator.log(logs, step=global_step) @@ -458,7 +461,7 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/train_network.py b/train_network.py index e45db052..34385ae0 100644 --- a/train_network.py +++ b/train_network.py @@ -1038,6 +1038,9 @@ class NetworkTrainer: # For --sample_at_first self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # training loop if initial_step > 0: # only if skip_until_initial_step is specified @@ -1224,7 +1227,7 @@ class NetworkTrainer: if args.scale_weight_norms: progress_bar.set_postfix(**{**max_mean_logs, **logs}) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = self.generate_step_logs( args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm ) @@ -1233,7 +1236,7 @@ class NetworkTrainer: if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 9044f50d..956c7860 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -550,6 +550,9 @@ class TextualInversionTrainer: unet, prompt_replacement, ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # training loop for epoch in range(num_train_epochs): @@ -684,7 +687,7 @@ class TextualInversionTrainer: remove_model(remove_ckpt_name) current_loss = loss.detach().item() - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} if ( args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() @@ -702,7 +705,7 @@ class TextualInversionTrainer: if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_total / len(train_dataloader)} accelerator.log(logs, step=epoch + 1) @@ -739,6 +742,7 @@ class TextualInversionTrainer: unet, prompt_replacement, ) + accelerator.log({}) # end of epoch diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index efb59137..ca0b603f 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -538,7 +538,7 @@ def train(args): remove_model(remove_ckpt_name) current_loss = loss.detach().item() - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} if ( args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() @@ -556,7 +556,7 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_total / len(train_dataloader)} accelerator.log(logs, step=epoch + 1) From 237317fffd060bcfb078b770ccd2df18bc4dd3a6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 11 Sep 2024 22:23:43 +0900 Subject: [PATCH 180/356] update README --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 2b3d0d5a..d3481b6a 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 11, 2024: +Logging to wandb is improved. See PR [#1576](https://github.com/kohya-ss/sd-scripts/pull/1576) for details. Thanks to p1atdev! + Sep 10, 2024: In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. The first value is the learning rate for CLIP-L, and the second value is for T5XXL. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. From cefe52629e1901dd8192b0487afd5e9f089e3519 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 12 Sep 2024 12:36:07 +0900 Subject: [PATCH 181/356] fix to work old notation for TE LR in .toml --- networks/lora_flux.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index d540c221..dd267de0 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -788,8 +788,11 @@ class LoRANetwork(torch.nn.Module): def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): # make sure text_encoder_lr as list of two elements - if text_encoder_lr is None or len(text_encoder_lr) == 0: + # if float, use the same value for both text encoders + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): text_encoder_lr = [default_lr, default_lr] + elif isinstance(text_encoder_lr, float): + text_encoder_lr = [text_encoder_lr, text_encoder_lr] elif len(text_encoder_lr) == 1: text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] From 1d7118a62268f12ebfd81c10db53bd85ef9d7631 Mon Sep 17 00:00:00 2001 From: Maru-mee <151493593+Maru-mee@users.noreply.github.com> Date: Fri, 13 Sep 2024 19:01:36 +0900 Subject: [PATCH 182/356] Support : OFT merge to base model (#1580) * Support : OFT merge to base model * Fix typo * Fix typo_2 * Delete unused parameter 'eye' --- networks/sdxl_merge_lora.py | 190 +++++++++++++++++++++++++++--------- 1 file changed, 143 insertions(+), 47 deletions(-) diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index 3383a80d..2c998c8c 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -8,10 +8,12 @@ from tqdm import tqdm from library import sai_model_spec, sdxl_model_util, train_util import library.model_util as model_util import lora +import oft from library.utils import setup_logging setup_logging() import logging logger = logging.getLogger(__name__) +import concurrent.futures def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -39,82 +41,176 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): else: torch.save(model, file_name) +def detect_method_from_training_model(models, dtype): + for model in models: + lora_sd, _ = load_state_dict(model, dtype) + for key in tqdm(lora_sd.keys()): + if 'lora_up' in key or 'lora_down' in key: + return 'LoRA' + elif "oft_blocks" in key: + return 'OFT' def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): text_encoder1.to(merge_dtype) text_encoder1.to(merge_dtype) unet.to(merge_dtype) + + # detect the method: OFT or LoRA_module + method = detect_method_from_training_model(models, merge_dtype) + logger.info(f"method:{method}") # create module map name_to_module = {} for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): - if i <= 1: - if i == 0: - prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 + if method == 'LoRA': + if i <= 1: + if i == 0: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 + else: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2 + target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE else: - prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2 - target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE - else: - prefix = lora.LoRANetwork.LORA_PREFIX_UNET - target_replace_modules = ( + prefix = lora.LoRANetwork.LORA_PREFIX_UNET + target_replace_modules = ( lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + ) + elif method == 'OFT': + prefix = oft.OFTNetwork.OFT_PREFIX_UNET + target_replace_modules = ( + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 ) for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: for child_name, child_module in module.named_modules(): - if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": - lora_name = prefix + "." + name + "." + child_name - lora_name = lora_name.replace(".", "_") - name_to_module[lora_name] = child_module - + if method == 'LoRA': + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + name_to_module[lora_name] = child_module + elif method == 'OFT': + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + oft_name = prefix + "." + name + "." + child_name + oft_name = oft_name.replace(".", "_") + name_to_module[oft_name] = child_module + + for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) logger.info(f"merging...") - for key in tqdm(lora_sd.keys()): - if "lora_down" in key: - up_key = key.replace("lora_down", "lora_up") - alpha_key = key[: key.index("lora_down")] + "alpha" - # find original module for this lora - module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if method == 'LoRA': + for key in tqdm(lora_sd.keys()): + if "lora_down" in key: + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + # find original module for this lora + module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if module_name not in name_to_module: + logger.info(f"no module found for LoRA weight: {key}") + continue + module = name_to_module[module_name] + # logger.info(f"apply {key} to {module}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + weight = module.weight + # logger.info(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + module.weight = torch.nn.Parameter(weight) + + + elif method == 'OFT': + + multiplier=1.0 + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + for key in tqdm(lora_sd.keys()): + if "oft_blocks" in key: + oft_blocks = lora_sd[key] + dim = oft_blocks.shape[0] + break + for key in tqdm(lora_sd.keys()): + if "alpha" in key: + oft_blocks = lora_sd[key] + alpha = oft_blocks.item() + break + + def merge_to(key): + if "alpha" in key: + return + + # find original module for this OFT + module_name = ".".join(key.split(".")[:-1]) if module_name not in name_to_module: - logger.info(f"no module found for LoRA weight: {key}") - continue + return module = name_to_module[module_name] + # logger.info(f"apply {key} to {module}") + + oft_blocks = lora_sd[key] + + if isinstance(module, torch.nn.Linear): + out_dim = module.out_features + elif isinstance(module, torch.nn.Conv2d): + out_dim = module.out_channels + + num_blocks = dim + block_size = out_dim // dim + constraint = (0 if alpha is None else alpha) * out_dim + + block_Q = oft_blocks - oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1) + block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + block_R_weighted = multiplier * block_R + (1 - multiplier) * I + R = torch.block_diag(*block_R_weighted) + + # get org weight + org_sd = module.state_dict() + org_weight = org_sd["weight"].to(device) - down_weight = lora_sd[key] - up_weight = lora_sd[up_key] - - dim = down_weight.size()[0] - alpha = lora_sd.get(alpha_key, dim) - scale = alpha / dim - - # W <- W + U * D - weight = module.weight - # logger.info(module_name, down_weight.size(), up_weight.size()) - if len(weight.size()) == 2: - # linear - weight = weight + ratio * (up_weight @ down_weight) * scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - weight - + ratio - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * scale - ) + R = R.to(org_weight.device, dtype=org_weight.dtype) + + if org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", org_weight, R) else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + ratio * conved * scale - + weight = torch.einsum("oi, op -> pi", org_weight, R) + + weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor + module.weight = torch.nn.Parameter(weight) + with concurrent.futures.ThreadPoolExecutor() as executor: + list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys()))) + def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): base_alphas = {} # alpha for merged model From 57ae44eb6138fe4a3864fffa62090f9d0113417d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 19:45:00 +0900 Subject: [PATCH 183/356] refactor to make safer --- networks/sdxl_merge_lora.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index 2c998c8c..d5a54e02 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -44,11 +44,11 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): def detect_method_from_training_model(models, dtype): for model in models: lora_sd, _ = load_state_dict(model, dtype) - for key in tqdm(lora_sd.keys()): - if 'lora_up' in key or 'lora_down' in key: - return 'LoRA' - elif "oft_blocks" in key: - return 'OFT' + for key in tqdm(lora_sd.keys()): + if 'lora_up' in key or 'lora_down' in key: + return 'LoRA' + elif "oft_blocks" in key: + return 'OFT' def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): text_encoder1.to(merge_dtype) @@ -76,6 +76,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ ) elif method == 'OFT': prefix = oft.OFTNetwork.OFT_PREFIX_UNET + # ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY target_replace_modules = ( oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 ) @@ -83,17 +84,11 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: for child_name, child_module in module.named_modules(): - if method == 'LoRA': - if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": - lora_name = prefix + "." + name + "." + child_name - lora_name = lora_name.replace(".", "_") - name_to_module[lora_name] = child_module - elif method == 'OFT': - if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": - oft_name = prefix + "." + name + "." + child_name - oft_name = oft_name.replace(".", "_") - name_to_module[oft_name] = child_module - + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + name_to_module[lora_name] = child_module + for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") @@ -168,6 +163,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ # find original module for this OFT module_name = ".".join(key.split(".")[:-1]) if module_name not in name_to_module: + logger.info(f"no module found for OFT weight: {key}") return module = name_to_module[module_name] @@ -208,7 +204,9 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ module.weight = torch.nn.Parameter(weight) - with concurrent.futures.ThreadPoolExecutor() as executor: + # TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough + max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys()))) From 3387dc7306087b84646666e49323980c89d14945 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 19:45:42 +0900 Subject: [PATCH 184/356] formatting, update README --- README.md | 6 +++ networks/sdxl_merge_lora.py | 86 +++++++++++++++++++++---------------- 2 files changed, 54 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index fd81a781..d5d2a7f7 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### Sep 13, 2024 / 2024-09-13: + +- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). Will be included in the next release. + +- `sdxl_merge_lora.py` が OFT をサポートしました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。次のリリースに含まれます。 + ### Jun 23, 2024 / 2024-06-23: - Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.) diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index d5a54e02..d5b6f7f3 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -10,11 +10,14 @@ import library.model_util as model_util import lora import oft from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) import concurrent.futures + def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": sd = load_file(file_name) @@ -41,20 +44,22 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): else: torch.save(model, file_name) + def detect_method_from_training_model(models, dtype): for model in models: lora_sd, _ = load_state_dict(model, dtype) for key in tqdm(lora_sd.keys()): - if 'lora_up' in key or 'lora_down' in key: - return 'LoRA' + if "lora_up" in key or "lora_down" in key: + return "LoRA" elif "oft_blocks" in key: - return 'OFT' + return "OFT" + def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): text_encoder1.to(merge_dtype) text_encoder1.to(merge_dtype) unet.to(merge_dtype) - + # detect the method: OFT or LoRA_module method = detect_method_from_training_model(models, merge_dtype) logger.info(f"method:{method}") @@ -62,7 +67,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ # create module map name_to_module = {} for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): - if method == 'LoRA': + if method == "LoRA": if i <= 1: if i == 0: prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 @@ -72,9 +77,9 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ else: prefix = lora.LoRANetwork.LORA_PREFIX_UNET target_replace_modules = ( - lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 ) - elif method == 'OFT': + elif method == "OFT": prefix = oft.OFTNetwork.OFT_PREFIX_UNET # ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY target_replace_modules = ( @@ -88,15 +93,14 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ lora_name = prefix + "." + name + "." + child_name lora_name = lora_name.replace(".", "_") name_to_module[lora_name] = child_module - - + for model, ratio in zip(models, ratios): logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) logger.info(f"merging...") - if method == 'LoRA': + if method == "LoRA": for key in tqdm(lora_sd.keys()): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") @@ -139,12 +143,11 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ module.weight = torch.nn.Parameter(weight) - - elif method == 'OFT': - - multiplier=1.0 - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - + elif method == "OFT": + + multiplier = 1.0 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + for key in tqdm(lora_sd.keys()): if "oft_blocks" in key: oft_blocks = lora_sd[key] @@ -154,12 +157,12 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ if "alpha" in key: oft_blocks = lora_sd[key] alpha = oft_blocks.item() - break - + break + def merge_to(key): if "alpha" in key: return - + # find original module for this OFT module_name = ".".join(key.split(".")[:-1]) if module_name not in name_to_module: @@ -168,18 +171,18 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ module = name_to_module[module_name] # logger.info(f"apply {key} to {module}") - + oft_blocks = lora_sd[key] - + if isinstance(module, torch.nn.Linear): out_dim = module.out_features elif isinstance(module, torch.nn.Conv2d): out_dim = module.out_channels - + num_blocks = dim block_size = out_dim // dim constraint = (0 if alpha is None else alpha) * out_dim - + block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=constraint) @@ -188,24 +191,24 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) block_R_weighted = multiplier * block_R + (1 - multiplier) * I R = torch.block_diag(*block_R_weighted) - + # get org weight org_sd = module.state_dict() org_weight = org_sd["weight"].to(device) R = R.to(org_weight.device, dtype=org_weight.dtype) - + if org_weight.dim() == 4: weight = torch.einsum("oihw, op -> pihw", org_weight, R) else: weight = torch.einsum("oi, op -> pi", org_weight, R) - - weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor - + + weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor + module.weight = torch.nn.Parameter(weight) # TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough - max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU + max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys()))) @@ -258,7 +261,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): for key in tqdm(lora_sd.keys()): if "alpha" in key: continue - + if "lora_up" in key and concat: concat_dim = 1 elif "lora_down" in key and concat: @@ -272,8 +275,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): alpha = alphas[lora_module_name] scale = math.sqrt(alpha / base_alpha) * ratio - scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 - + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + if key in merged_sd: assert ( merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None @@ -295,7 +298,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): dim = merged_sd[key_down].shape[0] perm = torch.randperm(dim) merged_sd[key_down] = merged_sd[key_down][perm] - merged_sd[key_up] = merged_sd[key_up][:,perm] + merged_sd[key_up] = merged_sd[key_up][:, perm] logger.info("merged model") logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") @@ -323,7 +326,9 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + assert len(args.models) == len( + args.ratios + ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" def str_to_dtype(p): if p == "float": @@ -410,10 +415,16 @@ def setup_parser() -> argparse.ArgumentParser: help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", ) parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + "--save_to", + type=str, + default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", ) parser.add_argument( - "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + "--models", + type=str, + nargs="*", + help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors", ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") parser.add_argument( @@ -431,8 +442,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--shuffle", action="store_true", - help="shuffle lora weight./ " - + "LoRAの重みをシャッフルする", + help="shuffle lora weight./ " + "LoRAの重みをシャッフルする", ) return parser From 734d2e5b2b7a1551f3750a15e71060f3beed98e9 Mon Sep 17 00:00:00 2001 From: terracottahaniwa <57107346+terracottahaniwa@users.noreply.github.com> Date: Fri, 13 Sep 2024 20:45:35 +0900 Subject: [PATCH 185/356] Support Lora Block Weight (LBW) to svd_merge_lora.py (#1575) * support lora block weight * solve license incompatibility * Fix issue: lbw index calculation --- networks/svd_merge_lora.py | 150 ++++++++++++++++++++++++++++++++++++- 1 file changed, 146 insertions(+), 4 deletions(-) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index cb00a600..6e163aec 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -1,5 +1,8 @@ import argparse +import itertools +import json import os +import re import time import torch from safetensors.torch import load_file, save_file @@ -14,6 +17,106 @@ logger = logging.getLogger(__name__) CLAMP_QUANTILE = 0.99 +ACCEPTABLE = [12, 17, 20, 26] +SDXL_LAYER_NUM = [12, 20] + +LAYER12 = { + "BASE": True, + "IN00": False, "IN01": False, "IN02": False, "IN03": False, "IN04": True, "IN05": True, + "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "MID": True, + "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": False, "OUT07": False, "OUT08": False, "OUT09": False, "OUT10": False, "OUT11": False +} + +LAYER17 = { + "BASE": True, + "IN00": False, "IN01": True, "IN02": True, "IN03": False, "IN04": True, "IN05": True, + "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "MID": True, + "OUT00": False, "OUT01": False, "OUT02": False, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, +} + +LAYER20 = { + "BASE": True, + "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, + "IN06": True, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "MID": True, + "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": False, "OUT10": False, "OUT11": False, +} + +LAYER26 = { + "BASE": True, + "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, + "IN06": True, "IN07": True, "IN08": True, "IN09": True, "IN10": True, "IN11": True, + "MID": True, + "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, + "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, +} + +assert len([v for v in LAYER12.values() if v]) == 12 +assert len([v for v in LAYER17.values() if v]) == 17 +assert len([v for v in LAYER20.values() if v]) == 20 +assert len([v for v in LAYER26.values() if v]) == 26 + +RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") + + +def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int: + # lbw block index is 0-based, but 0 for text encoder, so we return 0 for text encoder + if "text_model_encoder_" in lora_name: # LoRA for text encoder + return 0 + + # lbw block index is 1-based for U-Net, and no "input_blocks.0" in CompVis SD, so "input_blocks.1" have index 2 + block_idx = -1 # invalid lora name + if not is_sdxl: + NUM_OF_BLOCKS = 12 # up/down blocks + m = RE_UPDOWN.search(lora_name) + if m: + g = m.groups() + up_down = g[0] + i = int(g[1]) + j = int(g[3]) + if up_down == "down": + if g[2] == "resnets" or g[2] == "attentions": + idx = 3 * i + j + 1 + elif g[2] == "downsamplers": + idx = 3 * (i + 1) + else: + return block_idx # invalid lora name + elif up_down == "up": + if g[2] == "resnets" or g[2] == "attentions": + idx = 3 * i + j + elif g[2] == "upsamplers": + idx = 3 * i + 2 + else: + return block_idx # invalid lora name + + if g[0] == "down": + block_idx = 1 + idx # 1-based index, down block index + elif g[0] == "up": + block_idx = 1 + NUM_OF_BLOCKS + 1 + idx # 1-based index, num blocks, mid block, up block index + + elif "mid_block_" in lora_name: + block_idx = 1 + NUM_OF_BLOCKS # 1-based index, num blocks, mid block + else: + if lora_name.startswith("lora_unet_"): + name = lora_name[len("lora_unet_") :] + if name.startswith("time_embed_") or name.startswith("label_emb_"): # 1, No LoRA in sd-scripts + block_idx = 1 + elif name.startswith("input_blocks_"): # 1-8 to 2-9 + block_idx = 1 + int(name.split("_")[2]) + elif name.startswith("middle_block_"): # 10 + block_idx = 10 + elif name.startswith("output_blocks_"): # 0-8 to 11-19 + block_idx = 11 + int(name.split("_")[2]) + elif name.startswith("out_"): # 20, No LoRA in sd-scripts + block_idx = 20 + + return block_idx + def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -42,12 +145,34 @@ def save_to_file(file_name, state_dict, dtype, metadata): torch.save(state_dict, file_name) -def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): +def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype): logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") merged_sd = {} - v2 = None + v2 = None # This is meaning LoRA Metadata v2, Not meaning SD2 base_model = None - for model, ratio in zip(models, ratios): + + if lbws: + try: + # lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している + lbws = [json.loads(lbw) for lbw in lbws] + except Exception: + raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください") + assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください" + assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください" + assert all(len(lbw) in ACCEPTABLE for lbw in lbws), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください" + assert all(all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください" + + layer_num = len(lbws[0]) + is_sdxl = True if layer_num in SDXL_LAYER_NUM else False + FLAGS = { + "12": LAYER12.values(), + "17": LAYER17.values(), + "20": LAYER20.values(), + "26": LAYER26.values(), + }[str(layer_num)] + LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag] + + for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) @@ -57,6 +182,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty if base_model is None: base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + if lbw: + lbw_weights = [1] * 26 + for index, value in zip(LBW_TARGET_IDX, lbw): + lbw_weights[index] = value + print(dict(zip(LAYER26.keys(), lbw_weights))) + # merge logger.info(f"merging...") for key in tqdm(list(lora_sd.keys())): @@ -93,6 +224,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty # W <- W + U * D scale = alpha / network_dim + if lbw: + index = get_lbw_block_index(key, is_sdxl) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける + if device: # and isinstance(scale, torch.Tensor): scale = scale.to(device) @@ -170,6 +307,10 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty def merge(args): assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + if args.lbws: + assert len(args.models) == len(args.lbws), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" + else: + args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく def str_to_dtype(p): if p == "float": @@ -187,7 +328,7 @@ def merge(args): new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank state_dict, metadata, v2, base_model = merge_lora_models( - args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype + args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype ) logger.info(f"calculating hashes and creating metadata...") @@ -237,6 +378,7 @@ def setup_parser() -> argparse.ArgumentParser: "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率") parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") parser.add_argument( "--new_conv_rank", From f4a0bea6dce152e2210f611f94acfdfaa72068fe Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 21:26:06 +0900 Subject: [PATCH 186/356] format by black --- networks/svd_merge_lora.py | 188 +++++++++++++++++++++++++++++-------- 1 file changed, 147 insertions(+), 41 deletions(-) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 6e163aec..0decd904 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -11,8 +11,10 @@ from library import sai_model_spec, train_util import library.model_util as model_util import lora from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) CLAMP_QUANTILE = 0.99 @@ -22,38 +24,118 @@ SDXL_LAYER_NUM = [12, 20] LAYER12 = { "BASE": True, - "IN00": False, "IN01": False, "IN02": False, "IN03": False, "IN04": True, "IN05": True, - "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "IN00": False, + "IN01": False, + "IN02": False, + "IN03": False, + "IN04": True, + "IN05": True, + "IN06": False, + "IN07": True, + "IN08": True, + "IN09": False, + "IN10": False, + "IN11": False, "MID": True, - "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, - "OUT06": False, "OUT07": False, "OUT08": False, "OUT09": False, "OUT10": False, "OUT11": False + "OUT00": True, + "OUT01": True, + "OUT02": True, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": False, + "OUT07": False, + "OUT08": False, + "OUT09": False, + "OUT10": False, + "OUT11": False, } LAYER17 = { "BASE": True, - "IN00": False, "IN01": True, "IN02": True, "IN03": False, "IN04": True, "IN05": True, - "IN06": False, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "IN00": False, + "IN01": True, + "IN02": True, + "IN03": False, + "IN04": True, + "IN05": True, + "IN06": False, + "IN07": True, + "IN08": True, + "IN09": False, + "IN10": False, + "IN11": False, "MID": True, - "OUT00": False, "OUT01": False, "OUT02": False, "OUT03": True, "OUT04": True, "OUT05": True, - "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, + "OUT00": False, + "OUT01": False, + "OUT02": False, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": True, + "OUT07": True, + "OUT08": True, + "OUT09": True, + "OUT10": True, + "OUT11": True, } LAYER20 = { "BASE": True, - "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, - "IN06": True, "IN07": True, "IN08": True, "IN09": False, "IN10": False, "IN11": False, + "IN00": True, + "IN01": True, + "IN02": True, + "IN03": True, + "IN04": True, + "IN05": True, + "IN06": True, + "IN07": True, + "IN08": True, + "IN09": False, + "IN10": False, + "IN11": False, "MID": True, - "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, - "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": False, "OUT10": False, "OUT11": False, + "OUT00": True, + "OUT01": True, + "OUT02": True, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": True, + "OUT07": True, + "OUT08": True, + "OUT09": False, + "OUT10": False, + "OUT11": False, } LAYER26 = { "BASE": True, - "IN00": True, "IN01": True, "IN02": True, "IN03": True, "IN04": True, "IN05": True, - "IN06": True, "IN07": True, "IN08": True, "IN09": True, "IN10": True, "IN11": True, + "IN00": True, + "IN01": True, + "IN02": True, + "IN03": True, + "IN04": True, + "IN05": True, + "IN06": True, + "IN07": True, + "IN08": True, + "IN09": True, + "IN10": True, + "IN11": True, "MID": True, - "OUT00": True, "OUT01": True, "OUT02": True, "OUT03": True, "OUT04": True, "OUT05": True, - "OUT06": True, "OUT07": True, "OUT08": True, "OUT09": True, "OUT10": True, "OUT11": True, + "OUT00": True, + "OUT01": True, + "OUT02": True, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": True, + "OUT07": True, + "OUT08": True, + "OUT09": True, + "OUT10": True, + "OUT11": True, } assert len([v for v in LAYER12.values() if v]) == 12 @@ -145,6 +227,33 @@ def save_to_file(file_name, state_dict, dtype, metadata): torch.save(state_dict, file_name) +def format_lbws(lbws): + try: + # lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している + lbws = [json.loads(lbw) for lbw in lbws] + except Exception: + raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください") + assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください" + assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください" + assert all( + len(lbw) in ACCEPTABLE for lbw in lbws + ), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください" + assert all( + all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws + ), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください" + + layer_num = len(lbws[0]) + is_sdxl = True if layer_num in SDXL_LAYER_NUM else False + FLAGS = { + "12": LAYER12.values(), + "17": LAYER17.values(), + "20": LAYER20.values(), + "26": LAYER26.values(), + }[str(layer_num)] + LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag] + return lbws, is_sdxl, LBW_TARGET_IDX + + def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype): logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") merged_sd = {} @@ -152,25 +261,10 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer base_model = None if lbws: - try: - # lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している - lbws = [json.loads(lbw) for lbw in lbws] - except Exception: - raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください") - assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください" - assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください" - assert all(len(lbw) in ACCEPTABLE for lbw in lbws), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください" - assert all(all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください" - - layer_num = len(lbws[0]) - is_sdxl = True if layer_num in SDXL_LAYER_NUM else False - FLAGS = { - "12": LAYER12.values(), - "17": LAYER17.values(), - "20": LAYER20.values(), - "26": LAYER26.values(), - }[str(layer_num)] - LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag] + lbws, is_sdxl, LBW_TARGET_IDX = format_lbws(lbws) + else: + is_sdxl = False + LBW_TARGET_IDX = [] for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") @@ -186,7 +280,7 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer lbw_weights = [1] * 26 for index, value in zip(LBW_TARGET_IDX, lbw): lbw_weights[index] = value - print(dict(zip(LAYER26.keys(), lbw_weights))) + logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}") # merge logger.info(f"merging...") @@ -306,9 +400,13 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + assert len(args.models) == len( + args.ratios + ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" if args.lbws: - assert len(args.models) == len(args.lbws), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" + assert len(args.models) == len( + args.lbws + ), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" else: args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく @@ -372,10 +470,16 @@ def setup_parser() -> argparse.ArgumentParser: help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", ) parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + "--save_to", + type=str, + default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", ) parser.add_argument( - "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + "--models", + type=str, + nargs="*", + help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors", ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率") @@ -386,7 +490,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ", ) - parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) parser.add_argument( "--no_metadata", action="store_true", From b755ebd0a4dd2967171b6b5909624325359a2aa0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 21:29:31 +0900 Subject: [PATCH 187/356] add LBW support for SDXL merge LoRA --- README.md | 12 +++++- networks/sdxl_merge_lora.py | 75 ++++++++++++++++++++++++++++++++----- 2 files changed, 76 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index d5d2a7f7..0be2f9a7 100644 --- a/README.md +++ b/README.md @@ -139,9 +139,17 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Sep 13, 2024 / 2024-09-13: -- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). Will be included in the next release. +- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). +- `svd_merge_lora.py` now supports LBW. Thanks to terracottahaniwa. See PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) for details. +- `sdxl_merge_lora.py` also supports LBW. +- See [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) by hako-mikan for details on LBW. +- These will be included in the next release. -- `sdxl_merge_lora.py` が OFT をサポートしました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。次のリリースに含まれます。 +- `sdxl_merge_lora.py` が OFT をサポートされました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。 +- `svd_merge_lora.py` で LBW がサポートされました。PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) terracottahaniwa 氏に感謝します。 +- `sdxl_merge_lora.py` でも LBW がサポートされました。 +- LBW の詳細は hako-mikan 氏の [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) をご覧ください。 +- 以上は次回リリースに含まれます。 ### Jun 23, 2024 / 2024-06-23: diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index d5b6f7f3..62f5a87d 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -1,7 +1,9 @@ +import itertools import math import argparse import os import time +import concurrent.futures import torch from safetensors.torch import load_file, save_file from tqdm import tqdm @@ -9,13 +11,13 @@ from library import sai_model_spec, sdxl_model_util, train_util import library.model_util as model_util import lora import oft +from svd_merge_lora import format_lbws, get_lbw_block_index, LAYER26 from library.utils import setup_logging setup_logging() import logging logger = logging.getLogger(__name__) -import concurrent.futures def load_state_dict(file_name, dtype): @@ -47,6 +49,7 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): def detect_method_from_training_model(models, dtype): for model in models: + # TODO It is better to use key names to detect the method lora_sd, _ = load_state_dict(model, dtype) for key in tqdm(lora_sd.keys()): if "lora_up" in key or "lora_down" in key: @@ -55,15 +58,20 @@ def detect_method_from_training_model(models, dtype): return "OFT" -def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): - text_encoder1.to(merge_dtype) +def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws, merge_dtype): text_encoder1.to(merge_dtype) + text_encoder2.to(merge_dtype) unet.to(merge_dtype) # detect the method: OFT or LoRA_module method = detect_method_from_training_model(models, merge_dtype) logger.info(f"method:{method}") + if lbws: + lbws, _, LBW_TARGET_IDX = format_lbws(lbws) + else: + LBW_TARGET_IDX = [] + # create module map name_to_module = {} for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): @@ -94,12 +102,18 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ lora_name = lora_name.replace(".", "_") name_to_module[lora_name] = child_module - for model, ratio in zip(models, ratios): + for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) logger.info(f"merging...") + if lbw: + lbw_weights = [1] * 26 + for index, value in zip(LBW_TARGET_IDX, lbw): + lbw_weights[index] = value + logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}") + if method == "LoRA": for key in tqdm(lora_sd.keys()): if "lora_down" in key: @@ -121,6 +135,12 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ alpha = lora_sd.get(alpha_key, dim) scale = alpha / dim + if lbw: + index = get_lbw_block_index(key, True) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける + # W <- W + U * D weight = module.weight # logger.info(module_name, down_weight.size(), up_weight.size()) @@ -145,7 +165,6 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ elif method == "OFT": - multiplier = 1.0 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") for key in tqdm(lora_sd.keys()): @@ -183,6 +202,13 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ block_size = out_dim // dim constraint = (0 if alpha is None else alpha) * out_dim + multiplier = 1 + if lbw: + index = get_lbw_block_index(key, False) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + multiplier *= lbw_weights[index] + block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=constraint) @@ -213,17 +239,35 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys()))) -def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): +def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=False): base_alphas = {} # alpha for merged model base_dims = {} + # detect the method: OFT or LoRA_module + method = detect_method_from_training_model(models, merge_dtype) + if method == "OFT": + raise ValueError( + "OFT model is not supported for merging OFT models. / OFTモデルはOFTモデル同士のマージには対応していません" + ) + + if lbws: + lbws, _, LBW_TARGET_IDX = format_lbws(lbws) + else: + LBW_TARGET_IDX = [] + merged_sd = {} v2 = None base_model = None - for model, ratio in zip(models, ratios): + for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) + if lbw: + lbw_weights = [1] * 26 + for index, value in zip(LBW_TARGET_IDX, lbw): + lbw_weights[index] = value + logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}") + if lora_metadata is not None: if v2 is None: v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず @@ -277,6 +321,12 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): scale = math.sqrt(alpha / base_alpha) * ratio scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + if lbw: + index = get_lbw_block_index(key, True) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける + if key in merged_sd: assert ( merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None @@ -329,6 +379,12 @@ def merge(args): assert len(args.models) == len( args.ratios ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + if args.lbws: + assert len(args.models) == len( + args.lbws + ), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" + else: + args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく def str_to_dtype(p): if p == "float": @@ -356,7 +412,7 @@ def merge(args): ckpt_info, ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu") - merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype) + merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, args.lbws, merge_dtype) if args.no_metadata: sai_metadata = None @@ -372,7 +428,7 @@ def merge(args): args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype ) else: - state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle) logger.info(f"calculating hashes and creating metadata...") @@ -427,6 +483,7 @@ def setup_parser() -> argparse.ArgumentParser: help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors", ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率") parser.add_argument( "--no_metadata", action="store_true", From 93d9fbf60761fc1158e37f45f0d0c142913d70f5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 22:37:11 +0900 Subject: [PATCH 188/356] improve OFT implementation closes #944 --- README.md | 26 +++++++++- gen_img.py | 3 +- networks/check_lora_weights.py | 2 +- networks/oft.py | 94 ++++++++++++++++++++++------------ 4 files changed, 88 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 0130ccff..def528a2 100644 --- a/README.md +++ b/README.md @@ -143,7 +143,31 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. -- en: The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds! +- Improvements in OFT (Orthogonal Finetuning) Implementation + 1. Optimization of Calculation Order: + - Changed the calculation order in the forward method from (Wx)R to W(xR). + - This has improved computational efficiency and processing speed. + 2. Correction of Bias Application: + - In the previous implementation, R was incorrectly applied to the bias. + - The new implementation now correctly handles bias by using F.conv2d and F.linear. + 3. Efficiency Enhancement in Matrix Operations: + - Introduced einsum in both the forward and merge_to methods. + - This has optimized matrix operations, resulting in further speed improvements. + 4. Proper Handling of Data Types: + - Improved to use torch.float32 during calculations and convert results back to the original data type. + - This maintains precision while ensuring compatibility with the original model. + 5. Unified Processing for Conv2d and Linear Layers: + - Implemented a consistent method for applying OFT to both layer types. + - These changes have made the OFT implementation more efficient and accurate, potentially leading to improved model performance and training stability. + + - Additional Information + * Recommended α value for OFT constraint: We recommend using α values between 1e-4 and 1e-2. This differs slightly from the original implementation of "(α\*out_dim\*out_dim)". Our implementation uses "(α\*out_dim)", hence we recommend higher values than the 1e-5 suggested in the original implementation. + + * Performance Improvement: Training speed has been improved by approximately 30%. + + * Inference Environment: This implementation is compatible with and operates within Stable Diffusion web UI (SD1/2 and SDXL). + +- The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds! - See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler. - `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc. diff --git a/gen_img.py b/gen_img.py index d0a8f814..59bcd5b0 100644 --- a/gen_img.py +++ b/gen_img.py @@ -86,7 +86,8 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" """ -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): +# def replace_unet_modules(unet: diffusers.models.unets.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): +def replace_unet_modules(unet, mem_eff_attn, xformers, sdpa): if mem_eff_attn: logger.info("Enable memory efficient attention for U-Net") diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index 794659c9..f8eab53b 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -18,7 +18,7 @@ def main(file): keys = list(sd.keys()) for key in keys: - if "lora_up" in key or "lora_down" in key: + if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key or "oft_" in key: values.append((key, sd[key])) print(f"number of LoRA modules: {len(values)}") diff --git a/networks/oft.py b/networks/oft.py index 461a9869..6321def3 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -4,13 +4,17 @@ import math import os from typing import Dict, List, Optional, Tuple, Type, Union from diffusers import AutoencoderKL +import einops from transformers import CLIPTextModel import numpy as np import torch +import torch.nn.functional as F import re from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -45,11 +49,16 @@ class OFTModule(torch.nn.Module): if type(alpha) == torch.Tensor: alpha = alpha.detach().numpy() - self.constraint = alpha * out_dim + + # constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility + # original alpha is 1e-6, so we use 1e-3 or 1e-4 for alpha + self.constraint = alpha * out_dim + self.register_buffer("alpha", torch.tensor(alpha)) self.block_size = out_dim // self.num_blocks self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size)) + self.I = torch.eye(self.block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) # cpu self.out_dim = out_dim self.shape = org_module.weight.shape @@ -69,27 +78,36 @@ class OFTModule(torch.nn.Module): norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) - block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I - R = torch.block_diag(*block_R_weighted) - - return R + if self.I.device != block_Q.device: + self.I = self.I.to(block_Q.device) + I = self.I + block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) + block_R_weighted = self.multiplier * (block_R - I) + I + return block_R_weighted def forward(self, x, scale=None): - x = self.org_forward(x) if self.multiplier == 0.0: - return x + return self.org_forward(x) + org_module = self.org_module[0] + org_dtype = x.dtype - R = self.get_weight().to(x.device, dtype=x.dtype) - if x.dim() == 4: - x = x.permute(0, 2, 3, 1) - x = torch.matmul(x, R) - x = x.permute(0, 3, 1, 2) - else: - x = torch.matmul(x, R) - return x + R = self.get_weight().to(torch.float32) + W = org_module.weight.to(torch.float32) + + if len(W.shape) == 4: # Conv2d + W_reshaped = einops.rearrange(W, "(k n) ... -> k n ...", k=self.num_blocks, n=self.block_size) + RW = torch.einsum("k n m, k n ... -> k m ...", R, W_reshaped) + RW = einops.rearrange(RW, "k m ... -> (k m) ...") + result = F.conv2d( + x, RW.to(org_dtype), org_module.bias, org_module.stride, org_module.padding, org_module.dilation, org_module.groups + ) + else: # Linear + W_reshaped = einops.rearrange(W, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size) + RW = torch.einsum("k n m, k n p -> k m p", R, W_reshaped) + RW = einops.rearrange(RW, "k m p -> (k m) p") + result = F.linear(x, RW.to(org_dtype), org_module.bias) + return result class OFTInfModule(OFTModule): @@ -115,18 +133,19 @@ class OFTInfModule(OFTModule): return self.org_forward(x) return super().forward(x, scale) - def merge_to(self, multiplier=None, sign=1): - R = self.get_weight(multiplier) * sign - + def merge_to(self, multiplier=None): # get org weight org_sd = self.org_module[0].state_dict() - org_weight = org_sd["weight"] - R = R.to(org_weight.device, dtype=org_weight.dtype) + org_weight = org_sd["weight"].to(torch.float32) - if org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", org_weight, R) - else: - weight = torch.einsum("oi, op -> pi", org_weight, R) + R = self.get_weight(multiplier).to(torch.float32) + + weight = org_weight.reshape(self.num_blocks, self.block_size, -1) + weight = torch.einsum("k n m, k n ... -> k m ...", R, weight) + weight = weight.reshape(org_weight.shape) + + # convert back to original dtype + weight = weight.to(org_sd["weight"].dtype) # set weight to org_module org_sd["weight"] = weight @@ -145,8 +164,16 @@ def create_network( ): if network_dim is None: network_dim = 4 # default - if network_alpha is None: - network_alpha = 1.0 + if network_alpha is None: # should be set + logger.info( + "network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します" + ) + network_alpha = 1e-3 + elif network_alpha >= 1: + logger.warning( + "network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3" + " / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨" + ) enable_all_linear = kwargs.get("enable_all_linear", None) enable_conv = kwargs.get("enable_conv", None) @@ -190,12 +217,11 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh else: if dim is None: dim = param.size()[0] - if has_conv2d is None and param.dim() == 4: + if has_conv2d is None and "in_layers_2" in name: has_conv2d = True - if all_linear is None: - if param.dim() == 3 and "attn" not in name: - all_linear = True - if dim is not None and alpha is not None and has_conv2d is not None: + if all_linear is None and "_ff_" in name: + all_linear = True + if dim is not None and alpha is not None and has_conv2d is not None and all_linear is not None: break if has_conv2d is None: has_conv2d = False @@ -241,7 +267,7 @@ class OFTNetwork(torch.nn.Module): self.alpha = alpha logger.info( - f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}" + f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}, enable_all_linear: {enable_all_linear}" ) # create module instances From 2d8ee3c28007393386528cfeec0a9b714dafd85b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 14 Sep 2024 15:48:16 +0900 Subject: [PATCH 189/356] OFT for FLUX.1 --- flux_minimal_inference.py | 20 +- networks/lora_flux.py | 6 +- networks/oft.py | 2 +- networks/oft_flux.py | 482 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 504 insertions(+), 6 deletions(-) create mode 100644 networks/oft_flux.py diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index de607c52..2f1b9a37 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -14,9 +14,11 @@ from tqdm import tqdm from PIL import Image import accelerate from transformers import CLIPTextModel +from safetensors.torch import load_file from library import device_utils from library.device_utils import init_ipex, get_preferred_device +from networks import oft_flux init_ipex() @@ -405,7 +407,7 @@ if __name__ == "__main__": type=str, nargs="*", default=[], - help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)", + help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)", ) parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") parser.add_argument("--width", type=int, default=target_width) @@ -482,9 +484,19 @@ if __name__ == "__main__": else: multiplier = 1.0 - lora_model, weights_sd = lora_flux.create_network_from_weights( - multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True - ) + weights_sd = load_file(weights_file) + is_lora = is_oft = False + for key in weights_sd.keys(): + if key.startswith("lora"): + is_lora = True + if key.startswith("oft"): + is_oft = True + if is_lora or is_oft: + break + + module = lora_flux if is_lora else oft_flux + lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True) + if args.merge_lora_weights: lora_model.merge_to([clip_l, t5xxl], model, weights_sd) else: diff --git a/networks/lora_flux.py b/networks/lora_flux.py index dd267de0..ea7df8b4 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -41,7 +41,11 @@ class LoRAModule(torch.nn.Module): module_dropout=None, split_dims: Optional[List[int]] = None, ): - """if alpha == 0 or None, alpha is rank (no scaling).""" + """ + if alpha == 0 or None, alpha is rank (no scaling). + + split_dims is used to mimic the split qkv of FLUX as same as Diffusers + """ super().__init__() self.lora_name = lora_name diff --git a/networks/oft.py b/networks/oft.py index 6321def3..0c3a5393 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -51,7 +51,7 @@ class OFTModule(torch.nn.Module): alpha = alpha.detach().numpy() # constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility - # original alpha is 1e-6, so we use 1e-3 or 1e-4 for alpha + # original alpha is 1e-5, so we use 1e-2 or 1e-4 for alpha self.constraint = alpha * out_dim self.register_buffer("alpha", torch.tensor(alpha)) diff --git a/networks/oft_flux.py b/networks/oft_flux.py new file mode 100644 index 00000000..27b8b637 --- /dev/null +++ b/networks/oft_flux.py @@ -0,0 +1,482 @@ +# OFT network module + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +import einops +from transformers import CLIPTextModel +import numpy as np +import torch +import torch.nn.functional as F +import re +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class OFTModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + oft_name, + org_module: torch.nn.Module, + multiplier=1.0, + dim=4, + alpha=1, + split_dims: Optional[List[int]] = None, + ): + """ + dim -> num blocks + alpha -> constraint + + split_dims is used to mimic the split qkv of FLUX as same as Diffusers + """ + super().__init__() + self.oft_name = oft_name + self.num_blocks = dim + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().numpy() + self.register_buffer("alpha", torch.tensor(alpha)) + + # No conv2d in FLUX + # if "Linear" in org_module.__class__.__name__: + self.out_dim = org_module.out_features + # elif "Conv" in org_module.__class__.__name__: + # out_dim = org_module.out_channels + + if split_dims is None: + split_dims = [self.out_dim] + else: + assert sum(split_dims) == self.out_dim, "sum of split_dims must be equal to out_dim" + self.split_dims = split_dims + + # assert all dim is divisible by num_blocks + for split_dim in self.split_dims: + assert split_dim % self.num_blocks == 0, "split_dim must be divisible by num_blocks" + + self.constraint = [alpha * split_dim for split_dim in self.split_dims] + self.block_size = [split_dim // self.num_blocks for split_dim in self.split_dims] + self.oft_blocks = torch.nn.ParameterList( + [torch.nn.Parameter(torch.zeros(self.num_blocks, block_size, block_size)) for block_size in self.block_size] + ) + self.I = [torch.eye(block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) for block_size in self.block_size] + + self.shape = org_module.weight.shape + self.multiplier = multiplier + self.org_module = [org_module] # moduleにならないようにlistに入れる + + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + if self.I[0].device != self.oft_blocks[0].device: + self.I = [I.to(self.oft_blocks[0].device) for I in self.I] + + block_R_weighted_list = [] + for i in range(len(self.oft_blocks)): + block_Q = self.oft_blocks[i] - self.oft_blocks[i].transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint[i]) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + + I = self.I[i] + block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) + block_R_weighted = self.multiplier * (block_R - I) + I + + block_R_weighted_list.append(block_R_weighted) + + return block_R_weighted_list + + def forward(self, x, scale=None): + if self.multiplier == 0.0: + return self.org_forward(x) + + org_module = self.org_module[0] + org_dtype = x.dtype + + R = self.get_weight() + W = org_module.weight.to(torch.float32) + B = org_module.bias.to(torch.float32) + + # split W to match R + results = [] + d2 = 0 + for i in range(len(R)): + d1 = d2 + d2 += self.split_dims[i] + + W1 = W[d1:d2] + W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i]) + RW_1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped) + RW_1 = einops.rearrange(RW_1, "k m p -> (k m) p") + + B1 = B[d1:d2] + result = F.linear(x, RW_1.to(org_dtype), B1.to(org_dtype)) + results.append(result) + + result = torch.cat(results, dim=-1) + return result + + +class OFTInfModule(OFTModule): + def __init__( + self, + oft_name, + org_module: torch.nn.Module, + multiplier=1.0, + dim=4, + alpha=1, + split_dims: Optional[List[int]] = None, + **kwargs, + ): + # no dropout for inference + super().__init__(oft_name, org_module, multiplier, dim, alpha, split_dims) + self.enabled = True + self.network: OFTNetwork = None + + def set_network(self, network): + self.network = network + + def forward(self, x, scale=None): + if not self.enabled: + return self.org_forward(x) + return super().forward(x, scale) + + def merge_to(self, multiplier=None): + # get org weight + org_sd = self.org_module[0].state_dict() + W = org_sd["weight"].to(torch.float32) + R = self.get_weight(multiplier).to(torch.float32) + + d2 = 0 + W_list = [] + for i in range(len(self.oft_blocks)): + d1 = d2 + d2 += self.split_dims[i] + + W1 = W[d1:d2] + W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i]) + W1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped) + W1 = einops.rearrange(W1, "k m p -> (k m) p") + + W_list.append(W1) + + W = torch.cat(W_list, dim=-1) + + # convert back to original dtype + W = W.to(org_sd["weight"].dtype) + + # set weight to org_module + org_sd["weight"] = W + self.org_module[0].load_state_dict(org_sd) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: AutoencoderKL, + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], + unet, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: # should be set + logger.info( + "network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します" + ) + network_alpha = 1e-3 + elif network_alpha >= 1: + logger.warning( + "network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3" + " / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨" + ) + + # attn only or all linear (FFN) layers + enable_all_linear = kwargs.get("enable_all_linear", None) + # enable_conv = kwargs.get("enable_conv", None) + if enable_all_linear is not None: + enable_all_linear = bool(enable_all_linear) + # if enable_conv is not None: + # enable_conv = bool(enable_conv) + + network = OFTNetwork( + text_encoder, + unet, + multiplier=multiplier, + dim=network_dim, + alpha=network_alpha, + enable_all_linear=enable_all_linear, + varbose=True, + ) + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # check dim, alpha and if weights have for conv2d + dim = None + alpha = None + all_linear = None + for name, param in weights_sd.items(): + if name.endswith(".alpha"): + if alpha is None: + alpha = param.item() + elif "qkv" in name: + continue # ignore qkv + else: + if dim is None: + dim = param.size()[0] + if all_linear is None and "_mlp" in name: + all_linear = True + if dim is not None and alpha is not None and all_linear is not None: + break + if all_linear is None: + all_linear = False + + module_class = OFTInfModule if for_inference else OFTModule + network = OFTNetwork( + text_encoder, + unet, + multiplier=multiplier, + dim=dim, + alpha=alpha, + enable_all_linear=all_linear, + module_class=module_class, + ) + return network, weights_sd + + +class OFTNetwork(torch.nn.Module): + FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR = ["DoubleStreamBlock", "SingleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY = ["SelfAttention"] + OFT_PREFIX_UNET = "oft_unet" + + def __init__( + self, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + dim: int = 4, + alpha: float = 1, + enable_all_linear: Optional[bool] = False, + module_class: Union[Type[OFTModule], Type[OFTInfModule]] = OFTModule, + varbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.train_t5xxl = False # make compatible with LoRA + self.multiplier = multiplier + + self.dim = dim + self.alpha = alpha + + logger.info( + f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_all_linear: {enable_all_linear}" + ) + + # create module instances + def create_modules( + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[OFTModule]: + prefix = self.OFT_PREFIX_UNET + ofts = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = "Linear" in child_module.__class__.__name__ + + if is_linear: + oft_name = prefix + "." + name + "." + child_name + oft_name = oft_name.replace(".", "_") + # logger.info(oft_name) + + if "double" in oft_name and "qkv" in oft_name: + split_dims = [3072] * 3 + elif "single" in oft_name and "linear1" in oft_name: + split_dims = [3072] * 3 + [12288] + else: + split_dims = None + + oft = module_class(oft_name, child_module, self.multiplier, dim, alpha, split_dims) + ofts.append(oft) + return ofts + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + if enable_all_linear: + target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR + else: + target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY + + self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules) + logger.info(f"create OFT for Flux: {len(self.unet_ofts)} modules.") + + # assertion + names = set() + for oft in self.unet_ofts: + assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}" + names.add(oft.oft_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for oft in self.unet_ofts: + oft.multiplier = self.multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): + assert apply_unet, "apply_unet must be True" + + for oft in self.unet_ofts: + oft.apply_to() + self.add_module(oft.oft_name, oft) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): + logger.info("enable OFT for U-Net") + + for oft in self.unet_ofts: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(oft.oft_name): + sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key] + oft.load_state_dict(sd_for_lora, False) + oft.merge_to() + + logger.info(f"weights are merged") + + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + self.requires_grad_(True) + all_params = [] + + def enumerate_params(ofts): + params = [] + for oft in ofts: + params.extend(oft.parameters()) + + # logger.info num of params + num_params = 0 + for p in params: + num_params += p.numel() + logger.info(f"OFT params: {num_params}") + return params + + param_data = {"params": enumerate_params(self.unet_ofts)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) + + return all_params + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + oft.merge_to() + # sd = org_module.state_dict() + # org_weight = sd["weight"] + # lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype) + # sd["weight"] = org_weight + lora_weight + # assert sd["weight"].shape == org_weight.shape + # org_module.load_state_dict(sd) + + org_module._lora_restored = False + oft.enabled = False From c9ff4de90597e933b441502d45c175fe46b99714 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 14 Sep 2024 22:17:52 +0900 Subject: [PATCH 190/356] Add support for specifying rank for each layer in FLUX.1 --- README.md | 61 ++++++++++++++++++++++++ networks/lora_flux.py | 107 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 161 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 6e32fa31..9a979479 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 14, 2024: +- You can now specify the rank for each layer in FLUX.1. See [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) for details. +- OFT is now supported with FLUX.1. See [FLUX.1 OFT training](#flux1-oft-training) for details. + Sep 11, 2024: Logging to wandb is improved. See PR [#1576](https://github.com/kohya-ss/sd-scripts/pull/1576) for details. Thanks to p1atdev! @@ -46,6 +50,7 @@ Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. ` - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) - [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model) - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) +- [FLUX.1 OFT training](#flux1-oft-training) - [FLUX.1 fine-tuning](#flux1-fine-tuning) - [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning) - [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models) @@ -191,6 +196,62 @@ In the implementation of Black Forest Labs' model, the projection layers of q/k/ The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large. +#### Specify rank for each layer in FLUX.1 + +You can specify the rank for each layer in FLUX.1 by specifying the following network_args. If you specify `0`, LoRA will not be applied to that layer. + +When network_args is not specified, the default value (`network_dim`) is applied, same as before. + +|network_args|target layer| +|---|---| +|img_attn_dim|img_attn in DoubleStreamBlock| +|txt_attn_dim|txt_attn in DoubleStreamBlock| +|img_mlp_dim|img_mlp in DoubleStreamBlock| +|txt_mlp_dim|txt_mlp in DoubleStreamBlock| +|img_mod_dim|img_mod in DoubleStreamBlock| +|txt_mod_dim|txt_mod in DoubleStreamBlock| +|single_dim|linear1 and linear2 in SingleStreamBlock| +|single_mod_dim|modulation in SingleStreamBlock| + +example: +``` +--network_args "img_attn_dim=4" "img_mlp_dim=8" "txt_attn_dim=2" "txt_mlp_dim=2" +"img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2" +``` + +You can apply LoRA to the conditioning layers of Flux by specifying `in_dims` in network_args. When specifying, be sure to specify 5 numbers in `[]` as a comma-separated list. + +example: +``` +--network_args "in_dims=[4,2,2,2,4]" +``` + +Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, `txt_in`. The above example applies LoRA to all conditioning layers, with rank 4 for `img_in`, 2 for `time_in`, `vector_in`, `guidance_in`, and 4 for `txt_in`. + +If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,0,4]` applies LoRA only to `img_in` and `txt_in`. + +### FLUX.1 OFT training + +You can train OFT with almost the same options as LoRA, such as `--timestamp_sampling`. The following points are different. + +- Change `--network_module` from `networks.lora_flux` to `networks.oft_flux`. +- `--network_dim` is the number of OFT blocks. Unlike LoRA rank, the smaller the dim, the larger the model. We recommend about 64 or 128. Please make the output dimension of the target layer of OFT divisible by the value of `--network_dim` (an error will occur if it is not divisible). Valid values are 64, 128, 256, 512, 1024, etc. +- `--network_alpha` is treated as a constraint for OFT. We recommend about 1e-2 to 1e-4. The default value when omitted is 1, which is too large, so be sure to specify it. +- CLIP/T5XXL is not supported. Specify `--network_train_unet_only`. +- `--network_args` specifies the hyperparameters of OFT. The following are valid: + - Specify `enable_all_linear=True` to target all linear connections in the MLP layer. The default is False, which targets only attention. + +Currently, there is no environment to infer FLUX.1 OFT. Inference is only possible with `flux_minimal_inference.py` (specify OFT model with `--lora`). + +Sample command is below. It will work with 24GB VRAM GPUs with the batch size of 1. + +``` +--network_module networks.oft_flux --network_dim 128 --network_alpha 1e-3 +--network_args "enable_all_linear=True" --learning_rate 1e-5 +``` + +The training can be done with 16GB VRAM GPUs without `--enable_all_linear` option and with Adafactor optimizer. + ### Inference for FLUX.1 with LoRA model The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. diff --git a/networks/lora_flux.py b/networks/lora_flux.py index ea7df8b4..a34cde1a 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -316,6 +316,44 @@ def create_network( else: conv_alpha = float(conv_alpha) + # attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv + img_attn_dim = kwargs.get("img_attn_dim", None) + txt_attn_dim = kwargs.get("txt_attn_dim", None) + img_mlp_dim = kwargs.get("img_mlp_dim", None) + txt_mlp_dim = kwargs.get("txt_mlp_dim", None) + img_mod_dim = kwargs.get("img_mod_dim", None) + txt_mod_dim = kwargs.get("txt_mod_dim", None) + single_dim = kwargs.get("single_dim", None) # SingleStreamBlock + single_mod_dim = kwargs.get("single_mod_dim", None) # SingleStreamBlock + if img_attn_dim is not None: + img_attn_dim = int(img_attn_dim) + if txt_attn_dim is not None: + txt_attn_dim = int(txt_attn_dim) + if img_mlp_dim is not None: + img_mlp_dim = int(img_mlp_dim) + if txt_mlp_dim is not None: + txt_mlp_dim = int(txt_mlp_dim) + if img_mod_dim is not None: + img_mod_dim = int(img_mod_dim) + if txt_mod_dim is not None: + txt_mod_dim = int(txt_mod_dim) + if single_dim is not None: + single_dim = int(single_dim) + if single_mod_dim is not None: + single_mod_dim = int(single_mod_dim) + type_dims = [img_attn_dim, txt_attn_dim, img_mlp_dim, txt_mlp_dim, img_mod_dim, txt_mod_dim, single_dim, single_mod_dim] + if all([d is None for d in type_dims]): + type_dims = None + + # in_dims [img, time, vector, guidance, txt] + in_dims = kwargs.get("in_dims", None) + if in_dims is not None: + in_dims = in_dims.strip() + if in_dims.startswith("[") and in_dims.endswith("]"): + in_dims = in_dims[1:-1] + in_dims = [int(d) for d in in_dims.split(",")] # is it better to use ast.literal_eval? + assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)" + # rank/module dropout rank_dropout = kwargs.get("rank_dropout", None) if rank_dropout is not None: @@ -339,6 +377,11 @@ def create_network( if train_t5xxl is not None: train_t5xxl = True if train_t5xxl == "True" else False + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoders, @@ -354,7 +397,9 @@ def create_network( train_blocks=train_blocks, split_qkv=split_qkv, train_t5xxl=train_t5xxl, - varbose=True, + type_dims=type_dims, + in_dims=in_dims, + verbose=verbose, ) loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) @@ -462,7 +507,9 @@ class LoRANetwork(torch.nn.Module): train_blocks: Optional[str] = None, split_qkv: bool = False, train_t5xxl: bool = False, - varbose: Optional[bool] = False, + type_dims: Optional[List[int]] = None, + in_dims: Optional[List[int]] = None, + verbose: Optional[bool] = False, ) -> None: super().__init__() self.multiplier = multiplier @@ -478,12 +525,17 @@ class LoRANetwork(torch.nn.Module): self.split_qkv = split_qkv self.train_t5xxl = train_t5xxl + self.type_dims = type_dims + self.in_dims = in_dims + self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None self.loraplus_text_encoder_lr_ratio = None if modules_dim is not None: logger.info(f"create LoRA network from weights") + self.in_dims = [0] * 5 # create in_dims + # verbose = True else: logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") logger.info( @@ -502,7 +554,12 @@ class LoRANetwork(torch.nn.Module): # create module instances def create_modules( - is_flux: bool, text_encoder_idx: Optional[int], root_module: torch.nn.Module, target_replace_modules: List[str] + is_flux: bool, + text_encoder_idx: Optional[int], + root_module: torch.nn.Module, + target_replace_modules: List[str], + filter: Optional[str] = None, + default_dim: Optional[int] = None, ) -> List[LoRAModule]: prefix = ( self.LORA_PREFIX_FLUX @@ -513,16 +570,22 @@ class LoRANetwork(torch.nn.Module): loras = [] skipped = [] for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: # dirty hack for all modules + module = root_module # search all modules + for child_name, child_module in module.named_modules(): is_linear = child_module.__class__.__name__ == "Linear" is_conv2d = child_module.__class__.__name__ == "Conv2d" is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) if is_linear or is_conv2d: - lora_name = prefix + "." + name + "." + child_name + lora_name = prefix + "." + (name + "." if name else "") + child_name lora_name = lora_name.replace(".", "_") + if filter is not None and not filter in lora_name: + continue + dim = None alpha = None @@ -534,8 +597,25 @@ class LoRANetwork(torch.nn.Module): else: # 通常、すべて対象とする if is_linear or is_conv2d_1x1: - dim = self.lora_dim + dim = default_dim if default_dim is not None else self.lora_dim alpha = self.alpha + + if type_dims is not None: + identifier = [ + ("img_attn",), + ("txt_attn",), + ("img_mlp",), + ("txt_mlp",), + ("img_mod",), + ("txt_mod",), + ("single_blocks", "linear"), + ("modulation",), + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d + break + elif self.conv_lora_dim is not None: dim = self.conv_lora_dim alpha = self.conv_alpha @@ -566,6 +646,9 @@ class LoRANetwork(torch.nn.Module): split_dims=split_dims, ) loras.append(lora) + + if target_replace_modules is None: + break # all modules are searched return loras, skipped # create LoRA for text encoder @@ -594,10 +677,20 @@ class LoRANetwork(torch.nn.Module): self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules) + + # img, time, vector, guidance, txt + if self.in_dims: + for filter, in_dim in zip(["_img_in", "_time_in", "_vector_in", "_guidance_in", "_txt_in"], self.in_dims): + loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim) + self.unet_loras.extend(loras) + logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") skipped = skipped_te + skipped_un - if varbose and len(skipped) > 0: + if verbose and len(skipped) > 0: logger.warning( f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) From 6445bb2bc974cec51256ae38c1be0900e90e6f87 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 14 Sep 2024 22:37:26 +0900 Subject: [PATCH 191/356] update README --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9a979479..c94ea359 100644 --- a/README.md +++ b/README.md @@ -213,10 +213,12 @@ When network_args is not specified, the default value (`network_dim`) is applied |single_dim|linear1 and linear2 in SingleStreamBlock| |single_mod_dim|modulation in SingleStreamBlock| +`"verbose=True"` is also available for debugging. It shows the rank of each layer. + example: ``` --network_args "img_attn_dim=4" "img_mlp_dim=8" "txt_attn_dim=2" "txt_mlp_dim=2" -"img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2" +"img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2" "verbose=True" ``` You can apply LoRA to the conditioning layers of Flux by specifying `in_dims` in network_args. When specifying, be sure to specify 5 numbers in `[]` as a comma-separated list. From 9f44ef133083c530874c6cf022a4de8fda3edae2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 15 Sep 2024 13:52:23 +0900 Subject: [PATCH 192/356] add diffusers to FLUX.1 conversion script --- README.md | 19 ++- tools/convert_diffusers_to_flux.py | 223 +++++++++++++++++++++++++++++ 2 files changed, 241 insertions(+), 1 deletion(-) create mode 100644 tools/convert_diffusers_to_flux.py diff --git a/README.md b/README.md index c94ea359..7d6c336e 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,12 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 15, 2024: + +Added a script `convert_diffusers_to_flux.py` to convert Diffusers format FLUX.1 models (checkpoints) to BFL format. See `--help` for usage. Only Flux models are supported. AE/CLIP/T5XXL are not supported. + +The implementation is based on 2kpr's code. Thanks to 2kpr! + Sep 14, 2024: - You can now specify the rank for each layer in FLUX.1. See [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) for details. - OFT is now supported with FLUX.1. See [FLUX.1 OFT training](#flux1-oft-training) for details. @@ -57,6 +63,7 @@ Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. ` - [Convert FLUX LoRA](#convert-flux-lora) - [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) - [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) +- [Convert Diffusers to FLUX.1](#convert-diffusers-to-flux1) ### FLUX.1 LoRA training @@ -355,7 +362,7 @@ If you use LoRA in the inference environment, converting it to AI-toolkit format Note that re-conversion will increase the size of LoRA. -CLIP-L LoRA is not supported. +CLIP-L/T5XXL LoRA is not supported. ### Merge LoRA to FLUX.1 checkpoint @@ -435,6 +442,16 @@ resolution = [512, 512] num_repeats = 1 ``` +### Convert Diffusers to FLUX.1 + +Script: `convert_diffusers_to_flux1.py` + +Converts Diffusers models to FLUX.1 models. The script is experimental. See `--help` for options. schnell and dev models are supported. AE/CLIP/T5XXL are not supported. The diffusers folder is a parent folder of `transfomer` folder. + +``` +python tools/convert_diffusers_to_flux.py --diffusers_path path/to/diffusers_folder_or_00001_safetensors --save_to path/to/flux1.safetensors --mem_eff_load_save --save_precision bf16 +``` + ## SD3 training SD3 training is done with `sd3_train.py`. diff --git a/tools/convert_diffusers_to_flux.py b/tools/convert_diffusers_to_flux.py new file mode 100644 index 00000000..9d8f7c74 --- /dev/null +++ b/tools/convert_diffusers_to_flux.py @@ -0,0 +1,223 @@ +# This script converts the diffusers of a Flux model to a safetensors file of a Flux.1 model. +# It is based on the implementation by 2kpr. Thanks to 2kpr! +# Major changes: +# - Iterates over three safetensors files to reduce memory usage, not loading all tensors at once. +# - Makes reverse map from diffusers map to avoid loading all tensors. +# - Removes dependency on .json file for weights mapping. +# - Adds support for custom memory efficient load and save functions. +# - Supports saving with different precision. +# - Supports .safetensors file as input. + +# Copyright 2024 2kpr. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import os +from pathlib import Path +import safetensors +from safetensors.torch import safe_open +import torch +from tqdm import tqdm + +from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + +BFL_TO_DIFFUSERS_MAP = { + "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], + "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], + "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], + "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], + "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], + "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], + "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], + "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], + "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], + "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], + "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], + "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], + "txt_in.weight": ["context_embedder.weight"], + "txt_in.bias": ["context_embedder.bias"], + "img_in.weight": ["x_embedder.weight"], + "img_in.bias": ["x_embedder.bias"], + "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], + "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], + "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], + "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], + "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], + "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], + "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], + "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], + "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], + "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], + "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], + "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], + "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], + "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], + "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], + "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], + "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], + "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], + "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], + "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], + "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], + "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], + "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], + "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], + "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], + "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], + "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], + "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], + "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().linear2.bias": ["proj_out.bias"], + "final_layer.linear.weight": ["proj_out.weight"], + "final_layer.linear.bias": ["proj_out.bias"], + "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], + "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], +} + + +def convert(args): + # if diffusers_path is folder, get safetensors file + diffusers_path = Path(args.diffusers_path) + if diffusers_path.is_dir(): + diffusers_path = Path.joinpath(diffusers_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors") + + flux_path = Path(args.save_to) + if not os.path.exists(flux_path.parent): + os.makedirs(flux_path.parent) + + if not diffusers_path.exists(): + logger.error(f"Error: Missing transformer safetensors file: {diffusers_path}") + return + + mem_eff_flag = args.mem_eff_load_save + save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None + + # make reverse map from diffusers map + diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) + for b in range(NUM_DOUBLE_BLOCKS): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("double_blocks."): + block_prefix = f"transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for b in range(NUM_SINGLE_BLOCKS): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("single_blocks."): + block_prefix = f"single_transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")): + for i, weight in enumerate(weights): + diffusers_to_bfl_map[weight] = (i, key) + + # iterate over three safetensors files to reduce memory usage + flux_sd = {} + for i in range(3): + # replace 00001 with 0000i + current_diffusers_path = Path(str(diffusers_path).replace("00001", f"0000{i+1}")) + logger.info(f"Loading diffusers file: {current_diffusers_path}") + + open_func = MemoryEfficientSafeOpen if mem_eff_flag else (lambda x: safe_open(x, framework="pt")) + with open_func(current_diffusers_path) as f: + for diffusers_key in tqdm(f.keys()): + if diffusers_key in diffusers_to_bfl_map: + tensor = f.get_tensor(diffusers_key).to("cpu") + if save_dtype is not None: + tensor = tensor.to(save_dtype) + + index, bfl_key = diffusers_to_bfl_map[diffusers_key] + if bfl_key not in flux_sd: + flux_sd[bfl_key] = [] + flux_sd[bfl_key].append((index, tensor)) + else: + logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}") + return + + # concat tensors if multiple tensors are mapped to a single key, sort by index + for key, values in flux_sd.items(): + if len(values) == 1: + flux_sd[key] = values[0][1] + else: + flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])]) + + # special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + if "final_layer.adaLN_modulation.1.weight" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"]) + if "final_layer.adaLN_modulation.1.bias" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"]) + + # save flux_sd to safetensors file + logger.info(f"Saving Flux safetensors file: {flux_path}") + if mem_eff_flag: + mem_eff_save_file(flux_sd, flux_path) + else: + safetensors.torch.save_file(flux_sd, flux_path) + + logger.info("Conversion completed.") + + +def setup_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--diffusers_path", + default=None, + type=str, + required=True, + help="Path to the original Flux diffusers folder or *-00001-of-00003.safetensors file." + " / 元のFlux diffusersフォルダーまたは*-00001-of-00003.safetensorsファイルへのパス", + ) + parser.add_argument( + "--save_to", + default=None, + type=str, + required=True, + help="Output path for the Flux safetensors file. / Flux safetensorsファイルの出力先", + ) + parser.add_argument( + "--mem_eff_load_save", + action="store_true", + help="use custom memory efficient load and save functions for FLUX.1 model" + " / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する", + ) + parser.add_argument( + "--save_precision", + type=str, + default=None, + help="precision in saving, default is same as loading precision" + "float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz" + " / 保存時に精度を変更して保存する、デフォルトは読み込み時と同じ精度", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + args = parser.parse_args() + convert(args) From be078bdaca41084a20edb952b98a82f3e05d2dad Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 15 Sep 2024 13:59:17 +0900 Subject: [PATCH 193/356] fix typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7d6c336e..f79fe21a 100644 --- a/README.md +++ b/README.md @@ -446,7 +446,7 @@ resolution = [512, 512] Script: `convert_diffusers_to_flux1.py` -Converts Diffusers models to FLUX.1 models. The script is experimental. See `--help` for options. schnell and dev models are supported. AE/CLIP/T5XXL are not supported. The diffusers folder is a parent folder of `transfomer` folder. +Converts Diffusers models to FLUX.1 models. The script is experimental. See `--help` for options. schnell and dev models are supported. AE/CLIP/T5XXL are not supported. The diffusers folder is a parent folder of `rmer` folder. ``` python tools/convert_diffusers_to_flux.py --diffusers_path path/to/diffusers_folder_or_00001_safetensors --save_to path/to/flux1.safetensors --mem_eff_load_save --save_precision bf16 From 96c677b4594ed6f28f3ef896f6deca7c3aced25d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 16 Sep 2024 10:42:09 +0900 Subject: [PATCH 194/356] fix to work lienar/cosine lr scheduler closes #1602 ref #1393 --- library/train_util.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 742d057e..60afd421 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4707,6 +4707,15 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): **lr_scheduler_kwargs, ) + # these schedulers do not require `num_decay_steps` + if name == SchedulerType.LINEAR or name == SchedulerType.COSINE: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **lr_scheduler_kwargs, + ) + # All other schedulers require `num_decay_steps` if num_decay_steps is None: raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") @@ -5837,14 +5846,9 @@ def sample_image_inference( wandb_tracker = accelerator.get_tracker("wandb") import wandb + # not to commit images to avoid inconsistency between training and logging steps - wandb_tracker.log( - {f"sample_{i}": wandb.Image( - image, - caption=prompt # positive prompt as a caption - )}, - commit=False - ) + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption # endregion From d8d15f1a7e09ca217930288b41bd239881126b93 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 16 Sep 2024 23:14:09 +0900 Subject: [PATCH 195/356] add support for specifying blocks in FLUX.1 LoRA training --- README.md | 24 ++++++++++++- networks/lora_flux.py | 82 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 103 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f79fe21a..24217d8b 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 16, 2024: + + Added `train_double_block_indices` and `train_double_block_indices` to the LoRA training script to specify the indices of the blocks to train. See [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) for details. + Sep 15, 2024: Added a script `convert_diffusers_to_flux.py` to convert Diffusers format FLUX.1 models (checkpoints) to BFL format. See `--help` for usage. Only Flux models are supported. AE/CLIP/T5XXL are not supported. @@ -54,9 +58,12 @@ Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. ` - [FLUX.1 LoRA training](#flux1-lora-training) - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) - - [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model) + - [Distribution of timesteps](#distribution-of-timesteps) - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) + - [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) + - [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) - [FLUX.1 OFT training](#flux1-oft-training) +- [Inference for FLUX.1 with LoRA model](#inference-for-flux1-with-lora-model) - [FLUX.1 fine-tuning](#flux1-fine-tuning) - [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning) - [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models) @@ -239,6 +246,21 @@ Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, `txt If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,0,4]` applies LoRA only to `img_in` and `txt_in`. +#### Specify blocks to train in FLUX.1 LoRA training + +You can specify the blocks to train in FLUX.1 LoRA training by specifying `train_double_block_indices` and `train_single_block_indices` in network_args. The indices are 0-based. The default (when omitted) is to train all blocks. The indices are specified as a list of integers or a range of integers, like `0,1,5,8` or `0,1,4-5,7`. The number of double blocks is 19, and the number of single blocks is 38, so the valid range is 0-18 and 0-37, respectively. `all` is also available to train all blocks, `none` is also available to train no blocks. + +example: +``` +--network_args "train_double_block_indices=0,1,8-12,18" "train_single_block_indices=3,10,20-25,37" +``` + +``` +--network_args "train_double_block_indices=none" "train_single_block_indices=10-15" +``` + +If you specify one of `train_double_block_indices` or `train_single_block_indices`, the other will be trained as usual. + ### FLUX.1 OFT training You can train OFT with almost the same options as LoRA, such as `--timestamp_sampling`. The following points are different. diff --git a/networks/lora_flux.py b/networks/lora_flux.py index a34cde1a..f549ac18 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -24,6 +24,10 @@ import logging logger = logging.getLogger(__name__) +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + + class LoRAModule(torch.nn.Module): """ replaces forward method of the original Linear, instead of replacing the original Linear module. @@ -354,6 +358,50 @@ def create_network( in_dims = [int(d) for d in in_dims.split(",")] # is it better to use ast.literal_eval? assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)" + # double/single train blocks + def parse_block_selection(selection: str, total_blocks: int) -> List[bool]: + """ + Parse a block selection string and return a list of booleans. + + Args: + selection (str): A string specifying which blocks to select. + total_blocks (int): The total number of blocks available. + + Returns: + List[bool]: A list of booleans indicating which blocks are selected. + """ + if selection == "all": + return [True] * total_blocks + if selection == "none" or selection == "": + return [False] * total_blocks + + selected = [False] * total_blocks + ranges = selection.split(",") + + for r in ranges: + if "-" in r: + start, end = map(str.strip, r.split("-")) + start = int(start) + end = int(end) + assert 0 <= start < total_blocks, f"invalid start index: {start}" + assert 0 <= end < total_blocks, f"invalid end index: {end}" + assert start <= end, f"invalid range: {start}-{end}" + for i in range(start, end + 1): + selected[i] = True + else: + index = int(r) + assert 0 <= index < total_blocks, f"invalid index: {index}" + selected[index] = True + + return selected + + train_double_block_indices = kwargs.get("train_double_block_indices", None) + train_single_block_indices = kwargs.get("train_single_block_indices", None) + if train_double_block_indices is not None: + train_double_block_indices = parse_block_selection(train_double_block_indices, NUM_DOUBLE_BLOCKS) + if train_single_block_indices is not None: + train_single_block_indices = parse_block_selection(train_single_block_indices, NUM_SINGLE_BLOCKS) + # rank/module dropout rank_dropout = kwargs.get("rank_dropout", None) if rank_dropout is not None: @@ -399,6 +447,8 @@ def create_network( train_t5xxl=train_t5xxl, type_dims=type_dims, in_dims=in_dims, + train_double_block_indices=train_double_block_indices, + train_single_block_indices=train_single_block_indices, verbose=verbose, ) @@ -509,6 +559,8 @@ class LoRANetwork(torch.nn.Module): train_t5xxl: bool = False, type_dims: Optional[List[int]] = None, in_dims: Optional[List[int]] = None, + train_double_block_indices: Optional[List[bool]] = None, + train_single_block_indices: Optional[List[bool]] = None, verbose: Optional[bool] = False, ) -> None: super().__init__() @@ -527,6 +579,8 @@ class LoRANetwork(torch.nn.Module): self.type_dims = type_dims self.in_dims = in_dims + self.train_double_block_indices = train_double_block_indices + self.train_single_block_indices = train_single_block_indices self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -600,7 +654,7 @@ class LoRANetwork(torch.nn.Module): dim = default_dim if default_dim is not None else self.lora_dim alpha = self.alpha - if type_dims is not None: + if is_flux and type_dims is not None: identifier = [ ("img_attn",), ("txt_attn",), @@ -613,9 +667,33 @@ class LoRANetwork(torch.nn.Module): ] for i, d in enumerate(type_dims): if d is not None and all([id in lora_name for id in identifier[i]]): - dim = d + dim = d # may be 0 for skip break + if ( + is_flux + and dim + and ( + self.train_double_block_indices is not None + or self.train_single_block_indices is not None + ) + and ("double" in lora_name or "single" in lora_name) + ): + # "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..." + block_index = int(lora_name.split("_")[4]) # bit dirty + if ( + "double" in lora_name + and self.train_double_block_indices is not None + and not self.train_double_block_indices[block_index] + ): + dim = 0 + elif ( + "single" in lora_name + and self.train_single_block_indices is not None + and not self.train_single_block_indices[block_index] + ): + dim = 0 + elif self.conv_lora_dim is not None: dim = self.conv_lora_dim alpha = self.conv_alpha From 0cbe95bcc7e88f518802f29fe2b99da806963267 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 17 Sep 2024 21:21:28 +0900 Subject: [PATCH 196/356] fix text_encoder_lr to work with int closes #1608 --- networks/lora_flux.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index f549ac18..91e9cd77 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -966,8 +966,8 @@ class LoRANetwork(torch.nn.Module): # if float, use the same value for both text encoders if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): text_encoder_lr = [default_lr, default_lr] - elif isinstance(text_encoder_lr, float): - text_encoder_lr = [text_encoder_lr, text_encoder_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr)] elif len(text_encoder_lr) == 1: text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] From a2ad7e5644f08141fe053a2b63446d70d777bdcf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 17 Sep 2024 21:42:14 +0900 Subject: [PATCH 197/356] blocks_to_swap=0 means no swap --- flux_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux_train.py b/flux_train.py index 33481df8..5d8326b1 100644 --- a/flux_train.py +++ b/flux_train.py @@ -265,7 +265,7 @@ def train(args): flux.requires_grad_(True) - is_swapping_blocks = args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None + is_swapping_blocks = args.double_blocks_to_swap or args.single_blocks_to_swap if is_swapping_blocks: # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! From bbd160b4ca9293881c222f9b9e1d832af69699db Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Wed, 18 Sep 2024 07:55:04 +0900 Subject: [PATCH 198/356] sd3 schedule free opt (#1605) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * New ScheduleFree support for Flux (#1600) * init * use no schedule * fix typo * update for eval() * fix typo * update * Update train_util.py * Update requirements.txt * update sfwrapper WIP * no need to check schedulefree optimizer * remove debug print * comment out schedulefree wrapper * update readme --------- Co-authored-by: 青龍聖者@bdsqlsz <865105819@qq.com> --- README.md | 8 +++ library/train_util.py | 156 +++++++++++++++++++++++++++++++++++++++--- requirements.txt | 1 + 3 files changed, 156 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 24217d8b..dc986292 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,14 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 18, 2024: + +- Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details. + - `schedulefree` is added to the dependencies. Please update the library if necessary. + - AdamWScheduleFree or SGDScheduleFree can be used. Specify `adamwschedulefree` or `sgdschedulefree` in `--optimizer_type`. + - Wrapper classes are not available for now. + - These can be used not only for FLUX.1 training but also for other training scripts after merging to the dev/main branch. + Sep 16, 2024: Added `train_double_block_indices` and `train_double_block_indices` to the LoRA training script to specify the indices of the blocks to train. See [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) for details. diff --git a/library/train_util.py b/library/train_util.py index 60afd421..a54f23ff 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3303,6 +3303,20 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")', ) + # parser.add_argument( + # "--optimizer_schedulefree_wrapper", + # action="store_true", + # help="use schedulefree_wrapper any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用", + # ) + + # parser.add_argument( + # "--schedulefree_wrapper_args", + # type=str, + # default=None, + # nargs="*", + # help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / オプティマイザの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ...")', + # ) + parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ") parser.add_argument( "--lr_scheduler_args", @@ -4582,26 +4596,146 @@ def get_optimizer(args, trainable_params): optimizer_class = torch.optim.AdamW optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + elif optimizer_type.endswith("schedulefree".lower()): + try: + import schedulefree as sf + except ImportError: + raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") + if optimizer_type == "AdamWScheduleFree".lower(): + optimizer_class = sf.AdamWScheduleFree + logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}") + elif optimizer_type == "SGDScheduleFree".lower(): + optimizer_class = sf.SGDScheduleFree + logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}") + else: + raise ValueError(f"Unknown optimizer type: {optimizer_type}") + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + # make optimizer as train mode: we don't need to call train again, because eval will not be called in training loop + optimizer.train() + if optimizer is None: # 任意のoptimizerを使う - optimizer_type = args.optimizer_type # lowerでないやつ(微妙) - logger.info(f"use {optimizer_type} | {optimizer_kwargs}") - if "." not in optimizer_type: - optimizer_module = torch.optim - else: - values = optimizer_type.split(".") - optimizer_module = importlib.import_module(".".join(values[:-1])) - optimizer_type = values[-1] + case_sensitive_optimizer_type = args.optimizer_type # not lower + logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}") - optimizer_class = getattr(optimizer_module, optimizer_type) + if "." not in case_sensitive_optimizer_type: # from torch.optim + optimizer_module = torch.optim + else: # from other library + values = case_sensitive_optimizer_type.split(".") + optimizer_module = importlib.import_module(".".join(values[:-1])) + case_sensitive_optimizer_type = values[-1] + + optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + """ + # wrap any of above optimizer with schedulefree, if optimizer is not schedulefree + if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree".lower()): + try: + import schedulefree as sf + except ImportError: + raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") + + schedulefree_wrapper_kwargs = {} + if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0: + for arg in args.schedulefree_wrapper_args: + key, value = arg.split("=") + value = ast.literal_eval(value) + schedulefree_wrapper_kwargs[key] = value + + sf_wrapper = sf.ScheduleFreeWrapper(optimizer, **schedulefree_wrapper_kwargs) + sf_wrapper.train() # make optimizer as train mode + + # we need to make optimizer as a subclass of torch.optim.Optimizer, we make another Proxy class over SFWrapper + class OptimizerProxy(torch.optim.Optimizer): + def __init__(self, sf_wrapper): + self._sf_wrapper = sf_wrapper + + def __getattr__(self, name): + return getattr(self._sf_wrapper, name) + + # override properties + @property + def state(self): + return self._sf_wrapper.state + + @state.setter + def state(self, state): + self._sf_wrapper.state = state + + @property + def param_groups(self): + return self._sf_wrapper.param_groups + + @param_groups.setter + def param_groups(self, param_groups): + self._sf_wrapper.param_groups = param_groups + + @property + def defaults(self): + return self._sf_wrapper.defaults + + @defaults.setter + def defaults(self, defaults): + self._sf_wrapper.defaults = defaults + + def add_param_group(self, param_group): + self._sf_wrapper.add_param_group(param_group) + + def load_state_dict(self, state_dict): + self._sf_wrapper.load_state_dict(state_dict) + + def state_dict(self): + return self._sf_wrapper.state_dict() + + def zero_grad(self): + self._sf_wrapper.zero_grad() + + def step(self, closure=None): + self._sf_wrapper.step(closure) + + def train(self): + self._sf_wrapper.train() + + def eval(self): + self._sf_wrapper.eval() + + # isinstance チェックをパスするためのメソッド + def __instancecheck__(self, instance): + return isinstance(instance, (type(self), Optimizer)) + + optimizer = OptimizerProxy(sf_wrapper) + + logger.info(f"wrap optimizer with ScheduleFreeWrapper | {schedulefree_wrapper_kwargs}") + """ + optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) return optimizer_name, optimizer_args, optimizer +def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool: + return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper + + +def get_dummy_scheduler(optimizer: Optimizer) -> Any: + # dummy scheduler for schedulefree optimizer. supports only empty step(), get_last_lr() and optimizers. + # this scheduler is used for logging only. + # this isn't be wrapped by accelerator because of this class is not a subclass of torch.optim.lr_scheduler._LRScheduler + class DummyScheduler: + def __init__(self, optimizer: Optimizer): + self.optimizer = optimizer + + def step(self): + pass + + def get_last_lr(self): + return [group["lr"] for group in self.optimizer.param_groups] + + return DummyScheduler(optimizer) + + # Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler # Add some checking and features to the original function. @@ -4610,6 +4744,10 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): """ Unified API to get any scheduler from its name. """ + # if schedulefree optimizer, return dummy scheduler + if is_schedulefree_optimizer(optimizer, args): + return get_dummy_scheduler(optimizer) + name = args.lr_scheduler num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps num_warmup_steps: Optional[int] = ( diff --git a/requirements.txt b/requirements.txt index 9a4fa0c1..bab53f20 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ pytorch-lightning==1.9.0 bitsandbytes==0.43.3 prodigyopt==1.0 lion-pytorch==0.0.6 +schedulefree==1.2.7 tensorboard safetensors==0.4.4 # gradio==3.16.2 From e74502117bcf161ef5698fb0adba4f9fa0171b8d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 18 Sep 2024 08:04:32 +0900 Subject: [PATCH 199/356] update README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index dc986292..034a260f 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ The command to install PyTorch is as follows: Sep 18, 2024: - Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details. + - Details of the schedule-free optimizer can be found in [facebookresearch/schedule_free](https://github.com/facebookresearch/schedule_free). - `schedulefree` is added to the dependencies. Please update the library if necessary. - AdamWScheduleFree or SGDScheduleFree can be used. Specify `adamwschedulefree` or `sgdschedulefree` in `--optimizer_type`. - Wrapper classes are not available for now. From 1286e00bb0fc34c296f24b7057777f1c37cf8e11 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 18 Sep 2024 21:31:54 +0900 Subject: [PATCH 200/356] fix to call train/eval in schedulefree #1605 --- README.md | 3 +++ flux_train.py | 10 ++++++++++ library/train_util.py | 15 ++++++++++++++- train_network.py | 6 ++++++ 4 files changed, 33 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 034a260f..843ae181 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 18, 2024 (update 1): +Fixed an issue where train()/eval() was not called properly with the schedule-free optimizer. The schedule-free optimizer can be used in FLUX.1 LoRA training and fine-tuning for now. + Sep 18, 2024: - Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details. diff --git a/flux_train.py b/flux_train.py index 5d8326b1..bc4e6279 100644 --- a/flux_train.py +++ b/flux_train.py @@ -347,8 +347,13 @@ def train(args): logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) # prepare dataloader # strategies are set here because they cannot be referenced in another process. Copy them with the dataset @@ -760,6 +765,7 @@ def train(args): progress_bar.update(1) global_step += 1 + optimizer_eval_fn() flux_train_utils.sample_images( accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs ) @@ -778,6 +784,7 @@ def train(args): global_step, accelerator.unwrap_model(flux), ) + optimizer_train_fn() current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず if len(accelerator.trackers) > 0: @@ -800,6 +807,7 @@ def train(args): accelerator.wait_for_everyone() + optimizer_eval_fn() if args.save_every_n_epochs is not None: if accelerator.is_main_process: flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( @@ -816,12 +824,14 @@ def train(args): flux_train_utils.sample_images( accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs ) + optimizer_train_fn() is_main_process = accelerator.is_main_process # if is_main_process: flux = accelerator.unwrap_model(flux) accelerator.end_training() + optimizer_eval_fn() if args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) diff --git a/library/train_util.py b/library/train_util.py index a54f23ff..fe9deb94 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -13,6 +13,7 @@ import shutil import time from typing import ( Any, + Callable, Dict, List, NamedTuple, @@ -4715,8 +4716,20 @@ def get_optimizer(args, trainable_params): return optimizer_name, optimizer_args, optimizer +def get_optimizer_train_eval_fn(optimizer: Optimizer, args: argparse.Namespace) -> Tuple[Callable, Callable]: + if not is_schedulefree_optimizer(optimizer, args): + # return dummy func + return lambda: None, lambda: None + + # get train and eval functions from optimizer + train_fn = optimizer.train + eval_fn = optimizer.eval + + return train_fn, eval_fn + + def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool: - return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper + return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper def get_dummy_scheduler(optimizer: Optimizer) -> Any: diff --git a/train_network.py b/train_network.py index 34385ae0..55faa143 100644 --- a/train_network.py +++ b/train_network.py @@ -498,6 +498,7 @@ class NetworkTrainer: # accelerator.print(f"trainable_params: {k} = {v}") optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) # prepare dataloader # strategies are set here because they cannot be referenced in another process. Copy them with the dataset @@ -1199,6 +1200,7 @@ class NetworkTrainer: progress_bar.update(1) global_step += 1 + optimizer_eval_fn() self.sample_images( accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet ) @@ -1217,6 +1219,7 @@ class NetworkTrainer: if remove_step_no is not None: remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) remove_model(remove_ckpt_name) + optimizer_train_fn() current_loss = loss.detach().item() loss_recorder.add(epoch=epoch, step=step, loss=current_loss) @@ -1243,6 +1246,7 @@ class NetworkTrainer: accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 + optimizer_eval_fn() if args.save_every_n_epochs is not None: saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs if is_main_process and saving: @@ -1258,6 +1262,7 @@ class NetworkTrainer: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + optimizer_train_fn() # end of epoch @@ -1268,6 +1273,7 @@ class NetworkTrainer: network = accelerator.unwrap_model(network) accelerator.end_training() + optimizer_eval_fn() if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) From e7040669bc9a31706fe9fedec14978b05223f968 Mon Sep 17 00:00:00 2001 From: Maru-mee <151493593+Maru-mee@users.noreply.github.com> Date: Thu, 19 Sep 2024 15:47:06 +0900 Subject: [PATCH 201/356] Bug fix: alpha_mask load --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index a46d9487..5a8da90e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2207,7 +2207,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph if alpha_mask: if "alpha_mask" not in npz: return False - if npz["alpha_mask"].shape[0:2] != reso: # HxW + if (npz["alpha_mask"].shape[1], npz["alpha_mask"].shape[0]) != reso: # HxW => WxH != reso return False else: if "alpha_mask" in npz: From 9c757c2fba43d4e91d773cf6e9b7e2e8e3e8b376 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 19 Sep 2024 21:14:57 +0900 Subject: [PATCH 202/356] fix SDXL block index to match LBW --- networks/svd_merge_lora.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 0decd904..b4b9e3bf 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -184,18 +184,19 @@ def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int: elif "mid_block_" in lora_name: block_idx = 1 + NUM_OF_BLOCKS # 1-based index, num blocks, mid block else: + # SDXL: some numbers are skipped if lora_name.startswith("lora_unet_"): name = lora_name[len("lora_unet_") :] if name.startswith("time_embed_") or name.startswith("label_emb_"): # 1, No LoRA in sd-scripts block_idx = 1 elif name.startswith("input_blocks_"): # 1-8 to 2-9 block_idx = 1 + int(name.split("_")[2]) - elif name.startswith("middle_block_"): # 10 - block_idx = 10 - elif name.startswith("output_blocks_"): # 0-8 to 11-19 - block_idx = 11 + int(name.split("_")[2]) - elif name.startswith("out_"): # 20, No LoRA in sd-scripts - block_idx = 20 + elif name.startswith("middle_block_"): # 13 + block_idx = 13 + elif name.startswith("output_blocks_"): # 0-8 to 14-22 + block_idx = 14 + int(name.split("_")[2]) + elif name.startswith("out_"): # 23, No LoRA in sd-scripts + block_idx = 23 return block_idx From 3957372ded6fda20553acaf169993a422b829bdc Mon Sep 17 00:00:00 2001 From: Ed McManus Date: Thu, 19 Sep 2024 14:30:03 -0700 Subject: [PATCH 203/356] Retain alpha in `pil_resize` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently the alpha channel is dropped by `pil_resize()` when `--alpha_mask` is supplied and the image width does not exceed the bucket. This codepath is entered on the last line, here: ``` def trim_and_resize_if_required( random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int] ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]: image_height, image_width = image.shape[0:2] original_size = (image_width, image_height) # size before resize if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする if image_width > resized_size[0] and image_height > resized_size[1]: image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ else: image = pil_resize(image, resized_size) ``` --- library/utils.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/library/utils.py b/library/utils.py index a0bb1965..2171c719 100644 --- a/library/utils.py +++ b/library/utils.py @@ -305,13 +305,26 @@ class MemoryEfficientSafeOpen: raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)") def pil_resize(image, size, interpolation=Image.LANCZOS): - pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + # Check if the image has an alpha channel + has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False - # use Pillow resize + if has_alpha: + # Convert BGRA to RGBA + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)) + else: + # Convert BGR to RGB + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # Resize the image resized_pil = pil_image.resize(size, interpolation) - # return cv2 image - resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) + # Convert back to cv2 format + if has_alpha: + # Convert RGBA to BGRA + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA) + else: + # Convert RGB to BGR + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) return resized_cv2 From de4bb657b089cc28f4127e891b927895892e20b5 Mon Sep 17 00:00:00 2001 From: Ed McManus Date: Thu, 19 Sep 2024 14:38:32 -0700 Subject: [PATCH 204/356] Update utils.py Cleanup --- library/utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/library/utils.py b/library/utils.py index 2171c719..8a0c782c 100644 --- a/library/utils.py +++ b/library/utils.py @@ -305,25 +305,19 @@ class MemoryEfficientSafeOpen: raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)") def pil_resize(image, size, interpolation=Image.LANCZOS): - # Check if the image has an alpha channel has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False if has_alpha: - # Convert BGRA to RGBA pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)) else: - # Convert BGR to RGB pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - # Resize the image resized_pil = pil_image.resize(size, interpolation) # Convert back to cv2 format if has_alpha: - # Convert RGBA to BGRA resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA) else: - # Convert RGB to BGR resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) return resized_cv2 From 0535cd29b926530255d5400374813432ec52c3df Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Fri, 20 Sep 2024 10:05:22 +0800 Subject: [PATCH 205/356] fix: backward compatibility for text_encoder_lr --- train_network.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 55faa143..dfa51a9c 100644 --- a/train_network.py +++ b/train_network.py @@ -471,7 +471,11 @@ class NetworkTrainer: if support_multiple_lrs: text_encoder_lr = args.text_encoder_lr else: - text_encoder_lr = None if args.text_encoder_lr is None or len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] + # toml backward compatibility + if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float): + text_encoder_lr = args.text_encoder_lr + else: + text_encoder_lr = None if len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] try: if support_multiple_lrs: results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate) From 583d4a436c1cef57fce405d0167fb7ce575fc768 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 20 Sep 2024 22:22:24 +0900 Subject: [PATCH 206/356] add compatibility for int LR (D-Adaptation etc.) #1620 --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index dfa51a9c..b24f89b1 100644 --- a/train_network.py +++ b/train_network.py @@ -472,7 +472,7 @@ class NetworkTrainer: text_encoder_lr = args.text_encoder_lr else: # toml backward compatibility - if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float): + if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float) or isinstance(args.text_encoder_lr, int): text_encoder_lr = args.text_encoder_lr else: text_encoder_lr = None if len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] From e1f23af1bc733a1a89c35cf1be1301006c744b4a Mon Sep 17 00:00:00 2001 From: recris Date: Sat, 21 Sep 2024 12:58:32 +0100 Subject: [PATCH 207/356] make timestep sampling behave in the standard way when huber loss is used --- library/train_util.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 5a8da90e..72d2d811 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5124,34 +5124,27 @@ def save_sd_model_on_train_end_common( def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): - - # TODO: if a huber loss is selected, it will use constant timesteps for each batch - # as. In the future there may be a smarter way + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device='cpu') if args.loss_type == "huber" or args.loss_type == "smooth_l1": - timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu") - timestep = timesteps.item() - if args.huber_schedule == "exponential": alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps - huber_c = math.exp(-alpha * timestep) + huber_c = torch.exp(-alpha * timesteps) elif args.huber_schedule == "snr": - alphas_cumprod = noise_scheduler.alphas_cumprod[timestep] + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c elif args.huber_schedule == "constant": - huber_c = args.huber_c + huber_c = torch.full((b_size,), args.huber_c) else: raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - - timesteps = timesteps.repeat(b_size).to(device) + huber_c = huber_c.to(device) elif args.loss_type == "l2": - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) - huber_c = 1 # may be anything, as it's not used + huber_c = None # may be anything, as it's not used else: raise NotImplementedError(f"Unknown loss type {args.loss_type}") - timesteps = timesteps.long() + timesteps = timesteps.long().to(device) return timesteps, huber_c @@ -5190,20 +5183,21 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): return noise, noisy_latents, timesteps, huber_c -# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already def conditional_loss( - model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1 + model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor] ): if loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) elif loss_type == "huber": + huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) elif reduction == "sum": loss = torch.sum(loss) elif loss_type == "smooth_l1": + huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) From 29177d2f0389bd13e3f12c95d463fb0e1c58f9a1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 23 Sep 2024 21:14:03 +0900 Subject: [PATCH 208/356] retain alpha in pil_resize backport #1619 --- library/utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/library/utils.py b/library/utils.py index 5b7e657b..49d46a54 100644 --- a/library/utils.py +++ b/library/utils.py @@ -83,13 +83,20 @@ def setup_logging(args=None, log_level=None, reset=False): def pil_resize(image, size, interpolation=Image.LANCZOS): - pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False + + if has_alpha: + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)) + else: + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - # use Pillow resize resized_pil = pil_image.resize(size, interpolation) - # return cv2 image - resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) + # Convert back to cv2 format + if has_alpha: + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA) + else: + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) return resized_cv2 From ab7b23187062db86d34fc82db95f7266a68ab5c4 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 25 Sep 2024 19:38:52 +0800 Subject: [PATCH 209/356] init --- library/train_util.py | 21 ++++++++++++++++++--- requirements.txt | 2 +- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 5a8da90e..bdf7774e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2994,7 +2994,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): "--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, AdEMAMix8bit, PagedAdEMAMix8bit, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", ) # backward compatibility @@ -4032,7 +4032,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args): def get_optimizer(args, trainable_params): - # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" + # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" optimizer_type = args.optimizer_type if args.use_8bit_adam: @@ -4141,7 +4141,22 @@ def get_optimizer(args, trainable_params): raise AttributeError( "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) - + elif optimizer_type == "Ademamix8bit".lower(): + logger.info(f"use 8-bit Ademamix optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.AdEMAMix8bit + except AttributeError: + raise AttributeError( + "No Ademamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / Ademamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) + elif optimizer_type == "PagedAdemamix8bit".lower(): + logger.info(f"use 8-bit PagedAdemamix optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.PagedAdEMAMix8bit + except AttributeError: + raise AttributeError( + "No PagedAdemamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / PagedAdemamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower(): diff --git a/requirements.txt b/requirements.txt index 15e6e58f..e6e1bf6f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ ftfy==6.1.1 opencv-python==4.8.1.78 einops==0.7.0 pytorch-lightning==1.9.0 -bitsandbytes==0.43.0 +bitsandbytes==0.44.0 prodigyopt==1.0 lion-pytorch==0.0.6 tensorboard From e74f58148c5994889463afa42bb6fc5d6447a75e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 25 Sep 2024 20:55:50 +0900 Subject: [PATCH 210/356] update README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index def528a2..9eabdaee 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. +- Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris! + - Improvements in OFT (Orthogonal Finetuning) Implementation 1. Optimization of Calculation Order: - Changed the calculation order in the forward method from (Wx)R to W(xR). From 1beddd84e5c4db729a84356db227d981dc18cf8d Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 25 Sep 2024 22:58:26 +0800 Subject: [PATCH 211/356] delete code for cleaning --- library/train_util.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index bdf7774e..c4845c54 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4141,22 +4141,7 @@ def get_optimizer(args, trainable_params): raise AttributeError( "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) - elif optimizer_type == "Ademamix8bit".lower(): - logger.info(f"use 8-bit Ademamix optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.AdEMAMix8bit - except AttributeError: - raise AttributeError( - "No Ademamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / Ademamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" - ) - elif optimizer_type == "PagedAdemamix8bit".lower(): - logger.info(f"use 8-bit PagedAdemamix optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.PagedAdEMAMix8bit - except AttributeError: - raise AttributeError( - "No PagedAdemamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / PagedAdemamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" - ) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower(): From 56a7bc171d48089fb50f8638537e42d07c579db3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 08:26:31 +0900 Subject: [PATCH 212/356] new block swap for FLUX.1 fine tuning --- README.md | 47 ++++++-- flux_train.py | 241 +++++++++++++++++++++++++++-------------- library/flux_models.py | 172 ++++++++++++++++------------- 3 files changed, 294 insertions(+), 166 deletions(-) diff --git a/README.md b/README.md index ef691e91..7d623f90 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 26, 2024: +The implementation of block swap during FLUX.1 fine-tuning has been changed to improve speed about 10% (depends on the environment). A new `--blocks_to_swap` option has been added, and `--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. `--double_blocks_to_swap` and `--single_blocks_to_swap` are working as before, but they will be removed in the future. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. + + Sep 18, 2024 (update 1): Fixed an issue where train()/eval() was not called properly with the schedule-free optimizer. The schedule-free optimizer can be used in FLUX.1 LoRA training and fine-tuning for now. @@ -307,6 +311,8 @@ python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_ The memory-efficient training with block swap is based on 2kpr's implementation. Thanks to 2kpr! +__`--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. These options is still available, but they will be removed in the future. Please use `--blocks_to_swap` instead. These options are equivalent to specifying `double_blocks_to_swap + single_blocks_to_swap // 2` in `--blocks_to_swap`.__ + Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GPUs, and 64GB main memory is recommended. ``` @@ -319,39 +325,62 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 --timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 ---fused_backward_pass --double_blocks_to_swap 6 --cpu_offload_checkpointing --full_bf16 +--fused_backward_pass --blocks_to_swap 8 --full_bf16 ``` (The command is multi-line for readability. Please combine it into one line.) -Options are almost the same as LoRA training. The difference is `--full_bf16`, `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available. +Options are almost the same as LoRA training. The difference is `--full_bf16`, `--fused_backward_pass` and `--blocks_to_swap`. `--cpu_offload_checkpointing` is also available. `--full_bf16` enables the training with bf16 (weights and gradients). `--fused_backward_pass` enables the fusing of the optimizer step into the backward pass for each parameter. This reduces the memory usage during training. Only Adafactor optimizer is supported for now. Stochastic rounding is also enabled when `--fused_backward_pass` and `--full_bf16` are specified. -`--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. +`--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency and stochastic rounding. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. -`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks. Please see the next chapter for details. +`--blocks_to_swap` is the number of blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. The recommended maximum value is 36. -`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. +`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. All these options are experimental and may change in the future. The increasing the number of blocks to swap may reduce the memory usage, but the training speed will be slower. `--cpu_offload_checkpointing` also slows down the training. -Swap 6 double blocks and use cpu offload checkpointing may be a good starting point. Please try different settings according to VRAM usage and training speed. +Swap 8 blocks without cpu offload checkpointing may be a good starting point for 24GB VRAM GPUs. Please try different settings according to VRAM usage and training speed. The learning rate and the number of epochs are not optimized yet. Please adjust them according to the training results. +#### How to use block swap + +There are two possible ways to use block swap. It is unknown which is better. + +1. Swap the minimum number of blocks that fit in VRAM with batch size 1 and shorten the training speed of one step. + + The above command example is for this usage. + +2. Swap many blocks to increase the batch size and shorten the training speed per data. + + For example, swapping 20 blocks seems to increase the batch size to about 6. In this case, the training speed per data will be relatively faster than 1. + +#### Training with <24GB VRAM GPUs + +Swap 28 blocks without cpu offload checkpointing may be working with 12GB VRAM GPUs. Please try different settings according to VRAM size of your GPU. + +T5XXL requires about 10GB of VRAM, so 10GB of VRAM will be minimum requirement for FLUX.1 fine-tuning. + #### Key Features for FLUX.1 fine-tuning -1. Technical details of double/single block swap: +1. Technical details of block swap: - Reduce memory usage by transferring double and single blocks of FLUX.1 from GPU to CPU when they are not needed. - During forward pass, the weights of the blocks that have finished calculation are transferred to CPU, and the weights of the blocks to be calculated are transferred to GPU. - The same is true for the backward pass, but the order is reversed. The gradients remain on the GPU. - Since the transfer between CPU and GPU takes time, the training will be slower. - - `--double_blocks_to_swap` and `--single_blocks_to_swap` specify the number of blocks to swap. For example, `--double_blocks_to_swap 6` swaps 6 blocks at each step of training, but the remaining 13 blocks are always on the GPU. - - About 640MB of memory can be saved per double block, and about 320MB of memory can be saved per single block. + - `--blocks_to_swap` specify the number of blocks to swap. + - About 640MB of memory can be saved per block. + - Since the memory usage of one double block and two single blocks is almost the same, the transfer of single blocks is done in units of two. For example, consider the case of `--blocks_to_swap 6`. + - Before the forward pass, all double blocks and 26 (=38-12) single blocks are on the GPU. The last 12 single blocks are on the CPU. + - In the forward pass, the 6 double blocks that have finished calculation (the first 6 blocks) are transferred to the CPU, and the 12 single blocks to be calculated (the last 12 blocks) are transferred to the GPU. + - The same is true for the backward pass, but in reverse order. The 12 single blocks that have finished calculation are transferred to the CPU, and the 6 double blocks to be calculated are transferred to the GPU. + - After the backward pass, the blocks are back to their original locations. 2. Sample Image Generation: - Sample image generation during training is now supported. diff --git a/flux_train.py b/flux_train.py index bc4e6279..bf34208f 100644 --- a/flux_train.py +++ b/flux_train.py @@ -11,10 +11,12 @@ # - Per-block fused optimizer instances import argparse +from concurrent.futures import ThreadPoolExecutor import copy import math import os from multiprocessing import Value +import time from typing import List import toml @@ -265,14 +267,30 @@ def train(args): flux.requires_grad_(True) - is_swapping_blocks = args.double_blocks_to_swap or args.single_blocks_to_swap + # block swap + + # backward compatibility + if args.blocks_to_swap is None: + blocks_to_swap = args.double_blocks_to_swap or 0 + if args.single_blocks_to_swap is not None: + blocks_to_swap += args.single_blocks_to_swap // 2 + if blocks_to_swap > 0: + logger.warning( + "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead." + " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。" + ) + logger.info( + f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}." + ) + args.blocks_to_swap = blocks_to_swap + del blocks_to_swap + + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 if is_swapping_blocks: # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! - logger.info( - f"enable block swap: double_blocks_to_swap={args.double_blocks_to_swap}, single_blocks_to_swap={args.single_blocks_to_swap}" - ) - flux.enable_block_swap(args.double_blocks_to_swap, args.single_blocks_to_swap) + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + flux.enable_block_swap(args.blocks_to_swap) if not cache_latents: # load VAE here if not cached @@ -443,82 +461,120 @@ def train(args): # resumeする train_util.resume_from_local_or_hf_if_specified(accelerator, args) + # memory efficient block swapping + + def get_block_unit(dbl_blocks, sgl_blocks, index: int): + if index < len(dbl_blocks): + return (dbl_blocks[index],) + else: + index -= len(dbl_blocks) + index *= 2 + return (sgl_blocks[index], sgl_blocks[index + 1]) + + def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device): + def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc): + # print(f"Backward: Move block {bidx_to_cpu} to CPU") + for block in blocks_to_cpu: + block = block.to("cpu", non_blocking=True) + torch.cuda.empty_cache() + + # print(f"Backward: Move block {bidx_to_cuda} to CUDA") + for block in blocks_to_cuda: + block = block.to(dvc, non_blocking=True) + + torch.cuda.synchronize() + # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}") + return bidx_to_cpu, bidx_to_cuda + + blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu) + blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda) + + futures[block_idx_to_cuda] = thread_pool.submit( + move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device + ) + + def wait_blocks_move(block_idx, futures): + if block_idx not in futures: + return + # print(f"Backward: Wait for block {block_idx}") + # start_time = time.perf_counter() + future = futures.pop(block_idx) + future.result() + # print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") + # torch.cuda.synchronize() + # print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s") + if args.fused_backward_pass: # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) - double_blocks_to_swap = args.double_blocks_to_swap - single_blocks_to_swap = args.single_blocks_to_swap + blocks_to_swap = args.blocks_to_swap num_double_blocks = 19 # len(flux.double_blocks) num_single_blocks = 38 # len(flux.single_blocks) - handled_double_block_indices = set() - handled_single_block_indices = set() + num_block_units = num_double_blocks + num_single_blocks // 2 + handled_unit_indices = set() + + n = 1 # only asyncronous purpose, no need to increase this number + # n = 2 + # n = max(1, os.cpu_count() // 2) + thread_pool = ThreadPoolExecutor(max_workers=n) + futures = {} for param_group, param_name_group in zip(optimizer.param_groups, param_names): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: grad_hook = None - if double_blocks_to_swap: - if param_name.startswith("double_blocks"): + if blocks_to_swap: + is_double = param_name.startswith("double_blocks") + is_single = param_name.startswith("single_blocks") + if is_double or is_single: block_idx = int(param_name.split(".")[1]) - if ( - block_idx not in handled_double_block_indices - and block_idx >= (num_double_blocks - double_blocks_to_swap) - 1 - and block_idx < num_double_blocks - 1 - ): - # swap next (already backpropagated) block - handled_double_block_indices.add(block_idx) - block_idx_cpu = block_idx + 1 - block_idx_cuda = double_blocks_to_swap - (num_double_blocks - block_idx_cpu) + unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2 + if unit_idx not in handled_unit_indices: + # swap following (already backpropagated) block + handled_unit_indices.add(unit_idx) - # create swap hook - def create_double_swap_grad_hook(bidx, bidx_cuda): - def __grad_hook(tensor: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None + # if n blocks were already backpropagated + num_blocks_propagated = num_block_units - unit_idx - 1 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap + waiting = unit_idx > 0 and unit_idx <= blocks_to_swap + if swapping or waiting: + block_idx_to_cpu = num_block_units - num_blocks_propagated + block_idx_to_cuda = blocks_to_swap - num_blocks_propagated + block_idx_to_wait = unit_idx - 1 - # swap blocks if necessary - flux.double_blocks[bidx].to("cpu") - flux.double_blocks[bidx_cuda].to(accelerator.device) - # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device") + # create swap hook + def create_swap_grad_hook( + bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool + ): + def __grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None - return __grad_hook + # print(f"Backward: {uidx}, {swpng}, {wtng}") + if swpng: + submit_move_blocks( + futures, + thread_pool, + bidx_to_cpu, + bidx_to_cuda, + flux.double_blocks, + flux.single_blocks, + accelerator.device, + ) + if wtng: + wait_blocks_move(bidx_to_wait, futures) - grad_hook = create_double_swap_grad_hook(block_idx_cpu, block_idx_cuda) - if single_blocks_to_swap: - if param_name.startswith("single_blocks"): - block_idx = int(param_name.split(".")[1]) - if ( - block_idx not in handled_single_block_indices - and block_idx >= (num_single_blocks - single_blocks_to_swap) - 1 - and block_idx < num_single_blocks - 1 - ): - handled_single_block_indices.add(block_idx) - block_idx_cpu = block_idx + 1 - block_idx_cuda = single_blocks_to_swap - (num_single_blocks - block_idx_cpu) - # print(param_name, block_idx_cpu, block_idx_cuda) + return __grad_hook - # create swap hook - def create_single_swap_grad_hook(bidx, bidx_cuda): - def __grad_hook(tensor: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - - # swap blocks if necessary - flux.single_blocks[bidx].to("cpu") - flux.single_blocks[bidx_cuda].to(accelerator.device) - # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device") - - return __grad_hook - - grad_hook = create_single_swap_grad_hook(block_idx_cpu, block_idx_cuda) + grad_hook = create_swap_grad_hook( + block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting + ) if grad_hook is None: @@ -547,10 +603,15 @@ def train(args): num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} - double_blocks_to_swap = args.double_blocks_to_swap - single_blocks_to_swap = args.single_blocks_to_swap + blocks_to_swap = args.blocks_to_swap num_double_blocks = 19 # len(flux.double_blocks) num_single_blocks = 38 # len(flux.single_blocks) + num_block_units = num_double_blocks + num_single_blocks // 2 + + n = 1 # only asyncronous purpose, no need to increase this number + # n = max(1, os.cpu_count() // 2) + thread_pool = ThreadPoolExecutor(max_workers=n) + futures = {} for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: @@ -571,18 +632,30 @@ def train(args): optimizers[i].zero_grad(set_to_none=True) # swap blocks if necessary - if btype == "double" and double_blocks_to_swap: - if bidx >= num_double_blocks - double_blocks_to_swap: - bidx_cuda = double_blocks_to_swap - (num_double_blocks - bidx) - flux.double_blocks[bidx].to("cpu") - flux.double_blocks[bidx_cuda].to(accelerator.device) - # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device") - elif btype == "single" and single_blocks_to_swap: - if bidx >= num_single_blocks - single_blocks_to_swap: - bidx_cuda = single_blocks_to_swap - (num_single_blocks - bidx) - flux.single_blocks[bidx].to("cpu") - flux.single_blocks[bidx_cuda].to(accelerator.device) - # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device") + if blocks_to_swap and (btype == "double" or (btype == "single" and bidx % 2 == 0)): + unit_idx = bidx if btype == "double" else num_double_blocks + bidx // 2 + num_blocks_propagated = num_block_units - unit_idx + + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap + waiting = unit_idx > 0 and unit_idx <= blocks_to_swap + + if swapping: + block_idx_to_cpu = num_block_units - num_blocks_propagated + block_idx_to_cuda = blocks_to_swap - num_blocks_propagated + # print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}") + submit_move_blocks( + futures, + thread_pool, + block_idx_to_cpu, + block_idx_to_cuda, + flux.double_blocks, + flux.single_blocks, + accelerator.device, + ) + + if waiting: + block_idx_to_wait = unit_idx - 1 + wait_blocks_move(block_idx_to_wait, futures) return optimizer_hook @@ -881,24 +954,26 @@ def setup_parser() -> argparse.ArgumentParser: help="skip latents validity check / latentsの正当性チェックをスキップする", ) parser.add_argument( - "--double_blocks_to_swap", + "--blocks_to_swap", type=int, default=None, help="[EXPERIMENTAL] " - "Sets the number of 'double_blocks' (~640MB) to swap during the forward and backward passes." + "Sets the number of blocks (~640MB) to swap during the forward and backward passes." "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." - " / 順伝播および逆伝播中にスワップする'変換ブロック'(約640MB)の数を設定します。" + " / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。" "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", ) + parser.add_argument( + "--double_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) parser.add_argument( "--single_blocks_to_swap", type=int, default=None, - help="[EXPERIMENTAL] " - "Sets the number of 'single_blocks' (~320MB) to swap during the forward and backward passes." - "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." - " / 順伝播および逆伝播中にスワップする'変換ブロック'(約320MB)の数を設定します。" - "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", ) parser.add_argument( "--cpu_offload_checkpointing", diff --git a/library/flux_models.py b/library/flux_models.py index b5726c29..a35dbc10 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -2,9 +2,12 @@ # license: Apache-2.0 License +from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass import math -from typing import Optional +import os +import time +from typing import Dict, List, Optional from library.device_utils import init_ipex, clean_memory_on_device @@ -917,8 +920,10 @@ class Flux(nn.Module): self.gradient_checkpointing = False self.cpu_offload_checkpointing = False - self.double_blocks_to_swap = None - self.single_blocks_to_swap = None + self.blocks_to_swap = None + + self.thread_pool: Optional[ThreadPoolExecutor] = None + self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2 @property def device(self): @@ -956,38 +961,52 @@ class Flux(nn.Module): print("FLUX: Gradient checkpointing disabled.") - def enable_block_swap(self, double_blocks: Optional[int], single_blocks: Optional[int]): - self.double_blocks_to_swap = double_blocks - self.single_blocks_to_swap = single_blocks + def enable_block_swap(self, num_blocks: int): + self.blocks_to_swap = num_blocks + + n = 1 # async block swap. 1 is enough + # n = 2 + # n = max(1, os.cpu_count() // 2) + self.thread_pool = ThreadPoolExecutor(max_workers=n) def move_to_device_except_swap_blocks(self, device: torch.device): # assume model is on cpu - if self.double_blocks_to_swap: + if self.blocks_to_swap: save_double_blocks = self.double_blocks - self.double_blocks = None - if self.single_blocks_to_swap: save_single_blocks = self.single_blocks + self.double_blocks = None self.single_blocks = None self.to(device) - if self.double_blocks_to_swap: + if self.blocks_to_swap: self.double_blocks = save_double_blocks - if self.single_blocks_to_swap: self.single_blocks = save_single_blocks + def get_block_unit(self, index: int): + if index < len(self.double_blocks): + return (self.double_blocks[index],) + else: + index -= len(self.double_blocks) + index *= 2 + return self.single_blocks[index], self.single_blocks[index + 1] + + def get_unit_index(self, is_double: bool, index: int): + if is_double: + return index + else: + return len(self.double_blocks) + index // 2 + def prepare_block_swap_before_forward(self): - # move last n blocks to cpu: they are on cuda - if self.double_blocks_to_swap: - for i in range(len(self.double_blocks) - self.double_blocks_to_swap): - self.double_blocks[i].to(self.device) - for i in range(len(self.double_blocks) - self.double_blocks_to_swap, len(self.double_blocks)): - self.double_blocks[i].to("cpu") # , non_blocking=True) - if self.single_blocks_to_swap: - for i in range(len(self.single_blocks) - self.single_blocks_to_swap): - self.single_blocks[i].to(self.device) - for i in range(len(self.single_blocks) - self.single_blocks_to_swap, len(self.single_blocks)): - self.single_blocks[i].to("cpu") # , non_blocking=True) + # make: first n blocks are on cuda, and last n blocks are on cpu + if self.blocks_to_swap is None: + raise ValueError("Block swap is not enabled.") + for i in range(self.num_block_units - self.blocks_to_swap): + for b in self.get_block_unit(i): + b.to(self.device) + for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units): + for b in self.get_block_unit(i): + b.to("cpu") clean_memory_on_device(self.device) def forward( @@ -1017,69 +1036,73 @@ class Flux(nn.Module): ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) - if not self.double_blocks_to_swap: + if not self.blocks_to_swap: for block in self.double_blocks: img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - else: - # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning - for block_idx in range(self.double_blocks_to_swap): - block = self.double_blocks[len(self.double_blocks) - self.double_blocks_to_swap + block_idx] - if block.parameters().__next__().device.type != "cpu": - block.to("cpu") # , non_blocking=True) - # print(f"Moved double block {len(self.double_blocks) - self.double_blocks_to_swap + block_idx} to cpu.") - - block = self.double_blocks[block_idx] - if block.parameters().__next__().device.type == "cpu": - block.to(self.device) - # print(f"Moved double block {block_idx} to cuda.") - - to_cpu_block_index = 0 - for block_idx, block in enumerate(self.double_blocks): - # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda - moving = block_idx >= len(self.double_blocks) - self.double_blocks_to_swap - if moving: - block.to(self.device) # move to cuda - # print(f"Moved double block {block_idx} to cuda.") - - img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - - if moving: - self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) - # print(f"Moved double block {to_cpu_block_index} to cpu.") - to_cpu_block_index += 1 - - img = torch.cat((txt, img), 1) - - if not self.single_blocks_to_swap: + img = torch.cat((txt, img), 1) for block in self.single_blocks: img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: - # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning - for block_idx in range(self.single_blocks_to_swap): - block = self.single_blocks[len(self.single_blocks) - self.single_blocks_to_swap + block_idx] - if block.parameters().__next__().device.type != "cpu": - block.to("cpu") # , non_blocking=True) - # print(f"Moved single block {len(self.single_blocks) - self.single_blocks_to_swap + block_idx} to cpu.") + futures = {} - block = self.single_blocks[block_idx] - if block.parameters().__next__().device.type == "cpu": - block.to(self.device) - # print(f"Moved single block {block_idx} to cuda.") + def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda): + # print(f"Moving {bidx_to_cpu} to cpu.") + for block in blocks_to_cpu: + block.to("cpu", non_blocking=True) + torch.cuda.empty_cache() + + # print(f"Moving {bidx_to_cuda} to cuda.") + for block in blocks_to_cuda: + block.to(self.device, non_blocking=True) + + torch.cuda.synchronize() + # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") + return block_idx_to_cpu, block_idx_to_cuda + + blocks_to_cpu = self.get_block_unit(block_idx_to_cpu) + blocks_to_cuda = self.get_block_unit(block_idx_to_cuda) + # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.") + return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda) + + def wait_for_blocks_move(block_idx, ftrs): + if block_idx not in ftrs: + return + # print(f"Waiting for move blocks: {block_idx}") + # start_time = time.perf_counter() + ftr = ftrs.pop(block_idx) + ftr.result() + # torch.cuda.synchronize() + # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") + + for block_idx, block in enumerate(self.double_blocks): + # print(f"Double block {block_idx}") + unit_idx = self.get_unit_index(is_double=True, index=block_idx) + wait_for_blocks_move(unit_idx, futures) + + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + + if unit_idx < self.blocks_to_swap: + block_idx_to_cpu = unit_idx + block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx + future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) + futures[block_idx_to_cuda] = future + + img = torch.cat((txt, img), 1) - to_cpu_block_index = 0 for block_idx, block in enumerate(self.single_blocks): - # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda - moving = block_idx >= len(self.single_blocks) - self.single_blocks_to_swap - if moving: - block.to(self.device) # move to cuda - # print(f"Moved single block {block_idx} to cuda.") + # print(f"Single block {block_idx}") + unit_idx = self.get_unit_index(is_double=False, index=block_idx) + if block_idx % 2 == 0: + wait_for_blocks_move(unit_idx, futures) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if moving: - self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True) - # print(f"Moved single block {to_cpu_block_index} to cpu.") - to_cpu_block_index += 1 + if block_idx % 2 == 1 and unit_idx < self.blocks_to_swap: + block_idx_to_cpu = unit_idx + block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx + future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) + futures[block_idx_to_cuda] = future img = img[:, txt.shape[1] :, ...] @@ -1088,6 +1111,7 @@ class Flux(nn.Module): vec = vec.to(self.device) img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img From da94fd934eb4951d1cb132abc9d2a355e44d7abf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 08:27:48 +0900 Subject: [PATCH 213/356] fix typos --- flux_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flux_train.py b/flux_train.py index bf34208f..022467ea 100644 --- a/flux_train.py +++ b/flux_train.py @@ -516,7 +516,7 @@ def train(args): num_block_units = num_double_blocks + num_single_blocks // 2 handled_unit_indices = set() - n = 1 # only asyncronous purpose, no need to increase this number + n = 1 # only asynchronous purpose, no need to increase this number # n = 2 # n = max(1, os.cpu_count() // 2) thread_pool = ThreadPoolExecutor(max_workers=n) @@ -608,7 +608,7 @@ def train(args): num_single_blocks = 38 # len(flux.single_blocks) num_block_units = num_double_blocks + num_single_blocks // 2 - n = 1 # only asyncronous purpose, no need to increase this number + n = 1 # only asynchronous purpose, no need to increase this number # n = max(1, os.cpu_count() // 2) thread_pool = ThreadPoolExecutor(max_workers=n) futures = {} From bf91bea2e4363e5b3e0db11f0955ab93a19a0452 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 20:51:40 +0900 Subject: [PATCH 214/356] fix flip_aug, alpha_mask, random_crop issue in caching --- README.md | 2 ++ library/train_util.py | 44 +++++++++++++++++++++++++++++++------------ 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 9eabdaee..b67a2c4e 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. +- Fixed a bug in the cache of latents. When `flip_aug`, `alpha_mask`, and `random_crop` are different in multiple subsets in the dataset configuration file (.toml), the last subset is used instead of reflecting them correctly. + - Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris! - Improvements in OFT (Orthogonal Finetuning) Implementation diff --git a/library/train_util.py b/library/train_util.py index 72d2d811..a31d00c6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -998,9 +998,26 @@ class BaseDataset(torch.utils.data.Dataset): # sort by resolution image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) - # split by resolution - batches = [] - batch = [] + # split by resolution and some conditions + class Condition: + def __init__(self, reso, flip_aug, alpha_mask, random_crop): + self.reso = reso + self.flip_aug = flip_aug + self.alpha_mask = alpha_mask + self.random_crop = random_crop + + def __eq__(self, other): + return ( + self.reso == other.reso + and self.flip_aug == other.flip_aug + and self.alpha_mask == other.alpha_mask + and self.random_crop == other.random_crop + ) + + batches: List[Tuple[Condition, List[ImageInfo]]] = [] + batch: List[ImageInfo] = [] + current_condition = None + logger.info("checking cache validity...") for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] @@ -1021,28 +1038,31 @@ class BaseDataset(torch.utils.data.Dataset): if cache_available: # do not add to batch continue - # if last member of batch has different resolution, flush the batch - if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: - batches.append(batch) + # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + batches.append((current_condition, batch)) batch = [] batch.append(info) + current_condition = condition # if number of data in batch is enough, flush the batch if len(batch) >= vae_batch_size: - batches.append(batch) + batches.append((current_condition, batch)) batch = [] + current_condition = None if len(batch) > 0: - batches.append(batch) + batches.append((current_condition, batch)) if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only return # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded logger.info("caching latents...") - for batch in tqdm(batches, smoothing=1, total=len(batches)): - cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): + cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する @@ -2315,7 +2335,7 @@ def debug_dataset(train_dataset, show_input_ids=False): if "alpha_masks" in example and example["alpha_masks"] is not None: alpha_mask = example["alpha_masks"][j] logger.info(f"alpha mask size: {alpha_mask.size()}") - alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8) + alpha_mask = (alpha_mask.numpy() * 255.0).astype(np.uint8) if os.name == "nt": cv2.imshow("alpha_mask", alpha_mask) @@ -5124,7 +5144,7 @@ def save_sd_model_on_train_end_common( def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device='cpu') + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") if args.loss_type == "huber" or args.loss_type == "smooth_l1": if args.huber_schedule == "exponential": From 392e8dedd84e469b125e2935e3ecf02e6270a5b2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 21:14:11 +0900 Subject: [PATCH 215/356] fix flip_aug, alpha_mask, random_crop issue in caching in caching strategy --- library/train_util.py | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 319337a4..17dd447e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -993,9 +993,26 @@ class BaseDataset(torch.utils.data.Dataset): # sort by resolution image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) - # split by resolution - batches = [] - batch = [] + # split by resolution and some conditions + class Condition: + def __init__(self, reso, flip_aug, alpha_mask, random_crop): + self.reso = reso + self.flip_aug = flip_aug + self.alpha_mask = alpha_mask + self.random_crop = random_crop + + def __eq__(self, other): + return ( + self.reso == other.reso + and self.flip_aug == other.flip_aug + and self.alpha_mask == other.alpha_mask + and self.random_crop == other.random_crop + ) + + batches: List[Tuple[Condition, List[ImageInfo]]] = [] + batch: List[ImageInfo] = [] + current_condition = None + logger.info("checking cache validity...") for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] @@ -1016,20 +1033,23 @@ class BaseDataset(torch.utils.data.Dataset): if cache_available: # do not add to batch continue - # if last member of batch has different resolution, flush the batch - if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: - batches.append(batch) + # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + batches.append((current_condition, batch)) batch = [] batch.append(info) + current_condition = condition # if number of data in batch is enough, flush the batch if len(batch) >= caching_strategy.batch_size: - batches.append(batch) + batches.append((current_condition, batch)) batch = [] + current_condition = None if len(batch) > 0: - batches.append(batch) + batches.append((current_condition, batch)) # if cache to disk, don't cache latents in non-main process, set to info only if caching_strategy.cache_to_disk and not is_main_process: @@ -1041,9 +1061,8 @@ class BaseDataset(torch.utils.data.Dataset): # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded logger.info("caching latents...") - for batch in tqdm(batches, smoothing=1, total=len(batches)): - # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) - caching_strategy.cache_batch_latents(model, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): + caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと From a94bc84dec8e85e8a71217b4d2570a52c6779b73 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 21:37:31 +0900 Subject: [PATCH 216/356] fix to work bitsandbytes optimizers with full path #1640 --- library/train_util.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index b40945ab..47c36768 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3014,7 +3014,11 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): "--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, AdEMAMix8bit, PagedAdEMAMix8bit, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, " + "Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, " + "DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, " + "AdaFactor. " + "Also, you can use any optimizer by specifying the full path to the class, like 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit'.", ) # backward compatibility @@ -4105,6 +4109,7 @@ def get_optimizer(args, trainable_params): lr = args.learning_rate optimizer = None + optimizer_class = None if optimizer_type == "Lion".lower(): try: @@ -4162,7 +4167,8 @@ def get_optimizer(args, trainable_params): "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + if optimizer_class is not None: + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower(): logger.info(f"use PagedAdamW optimizer | {optimizer_kwargs}") @@ -4338,6 +4344,7 @@ def get_optimizer(args, trainable_params): optimizer_class = getattr(optimizer_module, optimizer_type) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + # for logging optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) From ce49ced699298aa885d9a64b969fe8c77f30893b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 21:37:40 +0900 Subject: [PATCH 217/356] update readme --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index b67a2c4e..9f024c1c 100644 --- a/README.md +++ b/README.md @@ -140,9 +140,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress - __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries. - - transformers, accelerate and huggingface_hub are updated. + - bitsandbytes, transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. +- `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds! + - There is no abbreviation, so please specify the full path like `--optimizer_type bitsandbytes.optim.AdEMAMix8bit` (not bnb but bitsandbytes). + - Fixed a bug in the cache of latents. When `flip_aug`, `alpha_mask`, and `random_crop` are different in multiple subsets in the dataset configuration file (.toml), the last subset is used instead of reflecting them correctly. - Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris! From 9249d00311002c84b189c2f6792cbe7aa344a1d5 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Thu, 26 Sep 2024 22:19:56 +0900 Subject: [PATCH 218/356] experimental support for multi-gpus latents caching --- library/train_util.py | 27 ++++++++++++++++----------- train_network.py | 2 +- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 3768b605..2ca662dc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -981,7 +981,7 @@ class BaseDataset(torch.utils.data.Dataset): ] ) - def new_cache_latents(self, model: Any, is_main_process: bool): + def new_cache_latents(self, model: Any, accelerator: Accelerator): r""" a brand new method to cache latents. This method caches latents with caching strategy. normal cache_latents method is used by default, but this method is used when caching strategy is specified. @@ -1013,8 +1013,12 @@ class BaseDataset(torch.utils.data.Dataset): batch: List[ImageInfo] = [] current_condition = None + # support multiple-gpus + num_processes = accelerator.num_processes + process_index = accelerator.process_index + logger.info("checking cache validity...") - for info in tqdm(image_infos): + for i, info in enumerate(tqdm(image_infos)): subset = self.image_to_subset[info.image_key] if info.latents_npz is not None: # fine tuning dataset @@ -1024,9 +1028,14 @@ class BaseDataset(torch.utils.data.Dataset): if caching_strategy.cache_to_disk: # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) - if not is_main_process: # prepare for multi-gpu, only store to info + + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different latents + if i % num_processes != process_index: continue + print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") + cache_available = caching_strategy.is_disk_cached_latents_expected( info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask ) @@ -1051,10 +1060,6 @@ class BaseDataset(torch.utils.data.Dataset): if len(batch) > 0: batches.append((current_condition, batch)) - # if cache to disk, don't cache latents in non-main process, set to info only - if caching_strategy.cache_to_disk and not is_main_process: - return - if len(batches) == 0: logger.info("no latents to cache") return @@ -2258,8 +2263,8 @@ class ControlNetDataset(BaseDataset): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) - def new_cache_latents(self, model: Any, is_main_process: bool): - return self.dreambooth_dataset_delegate.new_cache_latents(model, is_main_process) + def new_cache_latents(self, model: Any, accelerator: Accelerator): + return self.dreambooth_dataset_delegate.new_cache_latents(model, accelerator) def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process) @@ -2363,10 +2368,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset): logger.info(f"[Dataset {i}]") dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) - def new_cache_latents(self, model: Any, is_main_process: bool): + def new_cache_latents(self, model: Any, accelerator: Accelerator): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.new_cache_latents(model, is_main_process) + dataset.new_cache_latents(model, accelerator) def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True diff --git a/train_network.py b/train_network.py index b24f89b1..7eb7aa49 100644 --- a/train_network.py +++ b/train_network.py @@ -384,7 +384,7 @@ class NetworkTrainer: vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) From 24b1fdb66485af70b3c79feaf8ff1a348b66668e Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Thu, 26 Sep 2024 22:22:06 +0900 Subject: [PATCH 219/356] remove debug print --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 2ca662dc..8d6164b1 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1031,10 +1031,10 @@ class BaseDataset(torch.utils.data.Dataset): # if the modulo of num_processes is not equal to process_index, skip caching # this makes each process cache different latents - if i % num_processes != process_index: + if i % num_processes != process_index: continue - print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") + # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") cache_available = caching_strategy.is_disk_cached_latents_expected( info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask From a9aa52658a0d9ba7910a1d1983b650bc9de7153e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 28 Sep 2024 17:12:56 +0900 Subject: [PATCH 220/356] fix sample generation is not working in FLUX1 fine tuning #1647 --- library/flux_models.py | 5 +++-- library/flux_train_utils.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/library/flux_models.py b/library/flux_models.py index a35dbc10..0bc1c02b 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -999,8 +999,9 @@ class Flux(nn.Module): def prepare_block_swap_before_forward(self): # make: first n blocks are on cuda, and last n blocks are on cpu - if self.blocks_to_swap is None: - raise ValueError("Block swap is not enabled.") + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + # raise ValueError("Block swap is not enabled.") + return for i in range(self.num_block_units - self.blocks_to_swap): for b in self.get_block_unit(i): b.to(self.device) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f77d4b58..1d1eb9d2 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -313,6 +313,7 @@ def denoise( guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + model.prepare_block_swap_before_forward() pred = model( img=img, img_ids=img_ids, @@ -325,7 +326,8 @@ def denoise( ) img = img + (t_prev - t_curr) * pred - + + model.prepare_block_swap_before_forward() return img From 822fe578591e44ac949830e03a8841e222483052 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 28 Sep 2024 20:57:27 +0900 Subject: [PATCH 221/356] add workaround for 'Some tensors share memory' error #1614 --- networks/convert_flux_lora.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/networks/convert_flux_lora.py b/networks/convert_flux_lora.py index bd4c1cf7..fe6466eb 100644 --- a/networks/convert_flux_lora.py +++ b/networks/convert_flux_lora.py @@ -412,6 +412,10 @@ def main(args): state_dict = convert_ai_toolkit_to_sd_scripts(state_dict) elif args.src == "sd-scripts" and args.dst == "ai-toolkit": state_dict = convert_sd_scripts_to_ai_toolkit(state_dict) + + # eliminate 'shared tensors' + for k in list(state_dict.keys()): + state_dict[k] = state_dict[k].detach().clone() else: raise NotImplementedError(f"Conversion from {args.src} to {args.dst} is not supported") From 1a0f5b0c389f4e9fab5edb06b36f203e8894d581 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 00:35:29 +0900 Subject: [PATCH 222/356] re-fix sample generation is not working in FLUX1 split mode #1647 --- flux_train_network.py | 3 +++ library/flux_train_utils.py | 1 - 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/flux_train_network.py b/flux_train_network.py index a6e57eed..65b121e7 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -300,6 +300,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): self.flux_lower = flux_lower self.target_device = device + def prepare_block_swap_before_forward(self): + pass + def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): self.flux_lower.to("cpu") clean_memory_on_device(self.target_device) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 1d1eb9d2..b3c9184f 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -196,7 +196,6 @@ def sample_image_inference( tokens_and_masks = tokenize_strategy.tokenize(prompt) # strategy has apply_t5_attn_mask option encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) - print([x.shape if x is not None else None for x in encoded_text_encoder_conds]) # if text_encoder_conds is not cached, use encoded_text_encoder_conds if len(text_encoder_conds) == 0: From fe2aa32484a948f16955909e64c21da7fe1e4e0c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 09:49:25 +0900 Subject: [PATCH 223/356] adjust min/max bucket reso divisible by reso steps #1632 --- README.md | 2 ++ docs/config_README-en.md | 2 ++ docs/config_README-ja.md | 2 ++ fine_tune.py | 2 ++ library/train_util.py | 40 ++++++++++++++++++++++++++++++++------ train_controlnet.py | 2 ++ train_db.py | 2 ++ train_network.py | 2 +- train_textual_inversion.py | 2 +- 9 files changed, 48 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 9f024c1c..de5cddb9 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - bitsandbytes, transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. +- There was a bug where the min_bucket_reso/max_bucket_reso in the dataset configuration did not create the correct resolution bucket if it was not divisible by bucket_reso_steps. These values are now warned and automatically rounded to a divisible value. Thanks to Maru-mee for raising the issue. Related PR [#1632](https://github.com/kohya-ss/sd-scripts/pull/1632) + - `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds! - There is no abbreviation, so please specify the full path like `--optimizer_type bitsandbytes.optim.AdEMAMix8bit` (not bnb but bitsandbytes). diff --git a/docs/config_README-en.md b/docs/config_README-en.md index 83bea329..66a50dc0 100644 --- a/docs/config_README-en.md +++ b/docs/config_README-en.md @@ -128,6 +128,8 @@ These are options related to the configuration of the data set. They cannot be d * `batch_size` * This corresponds to the command-line argument `--train_batch_size`. +* `max_bucket_reso`, `min_bucket_reso` + * Specify the maximum and minimum resolutions of the bucket. It must be divisible by `bucket_reso_steps`. These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each. diff --git a/docs/config_README-ja.md b/docs/config_README-ja.md index cc74c341..0ed95e0e 100644 --- a/docs/config_README-ja.md +++ b/docs/config_README-ja.md @@ -118,6 +118,8 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 * `batch_size` * コマンドライン引数の `--train_batch_size` と同等です。 +* `max_bucket_reso`, `min_bucket_reso` + * bucketの最大、最小解像度を指定します。`bucket_reso_steps` で割り切れる必要があります。 これらの設定はデータセットごとに固定です。 つまり、データセットに所属するサブセットはこれらの設定を共有することになります。 diff --git a/fine_tune.py b/fine_tune.py index d865cd2d..b556672d 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -91,6 +91,8 @@ def train(args): ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + train_dataset_group.verify_bucket_reso_steps(64) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return diff --git a/library/train_util.py b/library/train_util.py index 47c36768..0cb6383a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -653,6 +653,34 @@ class BaseDataset(torch.utils.data.Dataset): # caching self.caching_mode = None # None, 'latents', 'text' + def adjust_min_max_bucket_reso_by_steps( + self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int + ) -> Tuple[int, int]: + # make min/max bucket reso to be multiple of bucket_reso_steps + if min_bucket_reso % bucket_reso_steps != 0: + adjusted_min_bucket_reso = min_bucket_reso - min_bucket_reso % bucket_reso_steps + logger.warning( + f"min_bucket_reso is adjusted to be multiple of bucket_reso_steps" + f" / min_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {min_bucket_reso} -> {adjusted_min_bucket_reso}" + ) + min_bucket_reso = adjusted_min_bucket_reso + if max_bucket_reso % bucket_reso_steps != 0: + adjusted_max_bucket_reso = max_bucket_reso + bucket_reso_steps - max_bucket_reso % bucket_reso_steps + logger.warning( + f"max_bucket_reso is adjusted to be multiple of bucket_reso_steps" + f" / max_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {max_bucket_reso} -> {adjusted_max_bucket_reso}" + ) + max_bucket_reso = adjusted_max_bucket_reso + + assert ( + min(resolution) >= min_bucket_reso + ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" + assert ( + max(resolution) <= max_bucket_reso + ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + + return min_bucket_reso, max_bucket_reso + def set_seed(self, seed): self.seed = seed @@ -1533,12 +1561,9 @@ class DreamBoothDataset(BaseDataset): self.enable_bucket = enable_bucket if self.enable_bucket: - assert ( - min(resolution) >= min_bucket_reso - ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" - assert ( - max(resolution) <= max_bucket_reso - ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps( + resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps + ) self.min_bucket_reso = min_bucket_reso self.max_bucket_reso = max_bucket_reso self.bucket_reso_steps = bucket_reso_steps @@ -1901,6 +1926,9 @@ class FineTuningDataset(BaseDataset): self.enable_bucket = enable_bucket if self.enable_bucket: + min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps( + resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps + ) self.min_bucket_reso = min_bucket_reso self.max_bucket_reso = max_bucket_reso self.bucket_reso_steps = bucket_reso_steps diff --git a/train_controlnet.py b/train_controlnet.py index c9ac6c5a..6938c4bc 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -107,6 +107,8 @@ def train(args): ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + train_dataset_group.verify_bucket_reso_steps(64) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return diff --git a/train_db.py b/train_db.py index 39d8ea6e..2c7f0258 100644 --- a/train_db.py +++ b/train_db.py @@ -93,6 +93,8 @@ def train(args): if args.no_token_padding: train_dataset_group.disable_token_padding() + train_dataset_group.verify_bucket_reso_steps(64) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return diff --git a/train_network.py b/train_network.py index 7ba07385..044ec3aa 100644 --- a/train_network.py +++ b/train_network.py @@ -95,7 +95,7 @@ class NetworkTrainer: return logs def assert_extra_args(self, args, train_dataset_group): - pass + train_dataset_group.verify_bucket_reso_steps(64) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index ade077c3..96e7bd50 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -99,7 +99,7 @@ class TextualInversionTrainer: self.is_sdxl = False def assert_extra_args(self, args, train_dataset_group): - pass + train_dataset_group.verify_bucket_reso_steps(64) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) From 1567549220b5936af0c534ca23656ecd2f4882f0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 09:51:36 +0900 Subject: [PATCH 224/356] update help text #1632 --- library/train_util.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 0cb6383a..422dceca 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3865,8 +3865,20 @@ def add_dataset_arguments( action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする", ) - parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") - parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") + parser.add_argument( + "--min_bucket_reso", + type=int, + default=256, + help="minimum resolution for buckets, must be divisible by bucket_reso_steps " + " / bucketの最小解像度、bucket_reso_stepsで割り切れる必要があります", + ) + parser.add_argument( + "--max_bucket_reso", + type=int, + default=1024, + help="maximum resolution for buckets, must be divisible by bucket_reso_steps " + " / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります", + ) parser.add_argument( "--bucket_reso_steps", type=int, From e0c3630203776dc568c32d67806a0a9d443f5721 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Sun, 29 Sep 2024 09:11:15 +0800 Subject: [PATCH 225/356] Support Sdxl Controlnet (#1648) * Create sdxl_train_controlnet.py * add fuse_background_pass * Update sdxl_train_controlnet.py * add fuse and fix error * update * Update sdxl_train_controlnet.py * Update sdxl_train_controlnet.py * Update sdxl_train_controlnet.py * update * Update sdxl_train_controlnet.py --- library/train_util.py | 2 +- sdxl_train_controlnet.py | 752 +++++++++++++++++++++++++++++++++++++++ train_controlnet.py | 31 +- 3 files changed, 778 insertions(+), 7 deletions(-) create mode 100644 sdxl_train_controlnet.py diff --git a/library/train_util.py b/library/train_util.py index e023f63a..293fc05a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3581,7 +3581,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: # available backends: # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 # https://pytorch.org/docs/stable/torch.compiler.html - choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], + choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt", "tensort", "ipex", "tvm"], help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)", ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") diff --git a/sdxl_train_controlnet.py b/sdxl_train_controlnet.py new file mode 100644 index 00000000..00026d2c --- /dev/null +++ b/sdxl_train_controlnet.py @@ -0,0 +1,752 @@ +import argparse +import math +import os +import random +from multiprocessing import Value +import toml + +from tqdm import tqdm + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from torch.nn.parallel import DistributedDataParallel as DDP +from accelerate.utils import set_seed +from diffusers import DDPMScheduler, ControlNetModel +from diffusers.utils.torch_utils import is_compiled_module +from safetensors.torch import load_file +from library import ( + deepspeed_utils, + sai_model_spec, + sdxl_model_util, + sdxl_original_unet, + sdxl_train_util, +) + +import library.model_util as model_util +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.huggingface_util as huggingface_util +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + add_v_prediction_like_loss, + apply_snr_weight, + prepare_scheduler_for_custom_training, + scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, +) +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +# TODO 他のスクリプトと共通化する +def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): + logs = { + "loss/current": current_loss, + "loss/average": avr_loss, + "lr": lr_scheduler.get_last_lr()[0], + } + + if args.optimizer_type.lower().startswith("DAdapt".lower()): + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[-1].param_groups[0]["d"] + * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + ) + + return logs + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + sdxl_train_util.verify_sdxl_training_args(args) + setup_logging(args, reset=True) + + cache_latents = args.cache_latents + use_user_config = args.dataset_config is not None + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) + if use_user_config: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "conditioning_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension, + ) + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = (train_dataset_group if args.max_data_loader_n_workers == 0 else None) + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(32) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + else: + logger.warning( + "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" + ) + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_train_util.load_target_model( + args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype + ) + + # convert U-Net + with torch.no_grad(): + du_unet_sd = sdxl_model_util.convert_sdxl_unet_state_dict_to_diffusers(unet.state_dict()) + unet.to("cpu") + clean_memory_on_device(accelerator.device) + del unet + unet = sdxl_model_util.UNet2DConditionModel(**sdxl_model_util.DIFFUSERS_SDXL_UNET_CONFIG) + unet.load_state_dict(du_unet_sd) + + controlnet = ControlNetModel.from_unet(unet) + + if args.controlnet_model_name_or_path: + filename = args.controlnet_model_name_or_path + if os.path.isfile(filename): + if os.path.splitext(filename)[1] == ".safetensors": + state_dict = load_file(filename) + else: + state_dict = torch.load(filename) + state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict) + controlnet.load_state_dict(state_dict) + elif os.path.isdir(filename): + controlnet = ControlNetModel.from_pretrained(filename) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents( + vae, + args.vae_batch_size, + args.cache_latents_to_disk, + accelerator.is_main_process, + ) + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad + with torch.no_grad(): + train_dataset_group.cache_text_encoder_outputs( + (tokenizer1, tokenizer2), + (text_encoder1, text_encoder2), + accelerator.device, + None, + args.cache_text_encoder_outputs_to_disk, + accelerator.is_main_process, + ) + accelerator.wait_for_everyone() + + # モデルに xformers とか memory efficient attention を組み込む + # train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if args.xformers: + unet.enable_xformers_memory_efficient_attention() + controlnet.enable_xformers_memory_efficient_attention() + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + controlnet.enable_gradient_checkpointing() + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + trainable_params = list(filter(lambda p: p.requires_grad, controlnet.parameters())) + logger.info(f"trainable params count: {len(trainable_params)}") + logger.info( + f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}" + ) + + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader)/ accelerator.num_processes/ args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix( + args, optimizer, accelerator.num_processes + ) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + controlnet.to(weight_dtype) + unet.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + controlnet.to(weight_dtype) + unet.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + unet.requires_grad_(False) + text_encoder1.requires_grad_(False) + text_encoder2.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + + # transform DDP after prepare + controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet + + controlnet.train() + + # TextEncoderの出力をキャッシュするときにはCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + text_encoder1.to("cpu", dtype=torch.float32) + text_encoder2.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm( + range(args.max_train_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="steps", + ) + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr( + noise_scheduler + ) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + ( + "controlnet_train" + if args.log_tracker_name is None + else args.log_tracker_name + ), + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + loss_recorder = train_util.LossRecorder() + del train_dataset_group + + # function for saving/removing + def save_model(ckpt_name, model, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) + sai_metadata["modelspec.architecture"] = ( + sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet" + ) + state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) + + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + if os.path.splitext(ckpt_file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, ckpt_file, sai_metadata) + else: + torch.save(state_dict, ckpt_file) + + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # For --sample_at_first + sdxl_train_util.sample_images( + accelerator, + args, + 0, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2], + unet, + controlnet=controlnet, + ) + + # training loop + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(controlnet): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = ( + batch["latents"] + .to(accelerator.device) + .to(dtype=weight_dtype) + ) + else: + # latentに変換 + latents = ( + vae.encode(batch["images"].to(dtype=vae_dtype)) + .latent_dist.sample() + .to(dtype=weight_dtype) + ) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print( + "NaN found in latents, replacing with zeros" + ) + latents = torch.nan_to_num(latents, 0, out=latents) + latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + + if ( + "text_encoder_outputs1_list" not in batch + or batch["text_encoder_outputs1_list"] is None + ): + input_ids1 = batch["input_ids"] + input_ids2 = batch["input_ids2"] + with torch.no_grad(): + # Get the text embedding for conditioning + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = ( + train_util.get_hidden_states_sdxl( + args.max_token_length, + input_ids1, + input_ids2, + tokenizer1, + tokenizer2, + text_encoder1, + text_encoder2, + None if not args.full_fp16 else weight_dtype, + ) + ) + else: + encoder_hidden_states1 = ( + batch["text_encoder_outputs1_list"] + .to(accelerator.device) + .to(weight_dtype) + ) + encoder_hidden_states2 = ( + batch["text_encoder_outputs2_list"] + .to(accelerator.device) + .to(weight_dtype) + ) + pool2 = ( + batch["text_encoder_pool2_list"] + .to(accelerator.device) + .to(weight_dtype) + ) + + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + # embs = sdxl_train_util.get_size_embeddings( + # orig_size, crop_size, target_size, accelerator.device + # ).to(weight_dtype) + + embs = torch.cat([orig_size, crop_size, target_size]).to(accelerator.device).to(weight_dtype) #B,6 + # concat embeddings + #vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + vector_embedding_dict = { + "text_embeds": pool2, + "time_ids": embs + } + text_embedding = torch.cat( + [encoder_hidden_states1, encoder_hidden_states2], dim=2 + ).to(weight_dtype) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps, huber_c = ( + train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + ) + + controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) + + + with accelerator.autocast(): + down_block_res_samples, mid_block_res_sample = controlnet( + noisy_latents, + timesteps, + encoder_hidden_states=text_embedding, + added_cond_kwargs=vector_embedding_dict, + controlnet_cond=controlnet_image, + return_dict=False, + ) + + # Predict the noise residual + noise_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states=text_embedding, + added_cond_kwargs=vector_embedding_dict, + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), + return_dict=False, + )[0] + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = train_util.conditional_loss( + noise_pred.float(),target.float(),reduction="none",loss_type=args.loss_type,huber_c=huber_c, + ) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss,timesteps,noise_scheduler,args.min_snr_gamma,args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if not args.fused_backward_pass: + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + sdxl_train_util.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2], + unet, + controlnet=controlnet, + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model(ckpt_name,unwrap_model(controlnet)) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + + current_loss = loss.detach().item() + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if args.logging_dir is not None: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + # 指定エポックごとにモデルを保存 + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name,unwrap_model(controlnet)) + + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) + + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + + sdxl_train_util.sample_images( + accelerator, + args, + epoch + 1, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2], + unet, + controlnet=controlnet, + ) + + # end of epoch + + if is_main_process: + controlnet = unwrap_model(controlnet) + + accelerator.end_training() + + if is_main_process and (args.save_state or args.save_state_on_train_end): + train_util.save_state_on_train_end(args, accelerator) + + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) + save_model( + ckpt_name, controlnet, force_sync_upload=True + ) + + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser) + + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="controlnet model name or path / controlnetのモデル名またはパス", + ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + + return parser + + +if __name__ == "__main__": + # sdxl_original_unet.USE_REENTRANT = False + + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/train_controlnet.py b/train_controlnet.py index c2945b08..8c7882c8 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -254,6 +254,7 @@ def train(args): accelerator.wait_for_everyone() if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() controlnet.enable_gradient_checkpointing() # 学習に必要なクラスを準備する @@ -304,6 +305,20 @@ def train(args): controlnet, optimizer, train_dataloader, lr_scheduler ) + if args.fused_backward_pass: + import library.adafactor_fused + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + unet.requires_grad_(False) text_encoder.requires_grad_(False) unet.to(accelerator.device) @@ -497,13 +512,17 @@ def train(args): loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = controlnet.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + if not args.fused_backward_pass: + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: From 8919b31145d38a2a790fae6e8e1c34c205c6794e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 23:07:34 +0900 Subject: [PATCH 226/356] use original ControlNet instead of Diffusers --- gen_img.py | 89 +++- library/sdxl_model_util.py | 2 +- library/sdxl_original_control_net.py | 272 ++++++++++++ library/sdxl_original_unet.py | 14 +- ...controlnet.py => sdxl_train_control_net.py | 386 ++++++++---------- 5 files changed, 526 insertions(+), 237 deletions(-) create mode 100644 library/sdxl_original_control_net.py rename sdxl_train_controlnet.py => sdxl_train_control_net.py (69%) diff --git a/gen_img.py b/gen_img.py index 59bcd5b0..70b3c81f 100644 --- a/gen_img.py +++ b/gen_img.py @@ -43,8 +43,8 @@ from diffusers import ( ) from einops import rearrange from tqdm import tqdm -from torchvision import transforms from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor +from accelerate import init_empty_weights import PIL from PIL import Image from PIL.PngImagePlugin import PngInfo @@ -58,6 +58,7 @@ import tools.original_control_net as original_control_net from tools.original_control_net import ControlNetInfo from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel from library.sdxl_original_unet import InferSdxlUNet2DConditionModel +from library.sdxl_original_control_net import SdxlControlNet from library.original_unet import FlashAttentionFunction from networks.control_net_lllite import ControlNetLLLite from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL @@ -352,8 +353,8 @@ class PipelineLike: self.token_replacements_list.append({}) # ControlNet - self.control_nets: List[ControlNetInfo] = [] # only for SD 1.5 - self.control_net_lllites: List[ControlNetLLLite] = [] + self.control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = [] + self.control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない self.gradual_latent: GradualLatent = None @@ -542,7 +543,7 @@ class PipelineLike: else: text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) - if self.control_net_lllites: + if self.control_net_lllites or (self.control_nets and self.is_sdxl): # ControlNetのhintにguide imageを流用する。ControlNetの場合はControlNet側で行う if isinstance(clip_guide_images, PIL.Image.Image): clip_guide_images = [clip_guide_images] @@ -731,7 +732,12 @@ class PipelineLike: num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 if self.control_nets: - guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + if not self.is_sdxl: + guided_hints = original_control_net.get_guided_hints( + self.control_nets, num_latent_input, batch_size, clip_guide_images + ) + else: + clip_guide_images = clip_guide_images * 0.5 + 0.5 # [-1, 1] => [0, 1] each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) if self.control_net_lllites: @@ -793,7 +799,7 @@ class PipelineLike: latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # disable ControlNet-LLLite if ratio is set. ControlNet is disabled in ControlNetInfo + # disable ControlNet-LLLite or SDXL ControlNet if ratio is set. ControlNet is disabled in ControlNetInfo if self.control_net_lllites: for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_net_lllites, each_control_net_enabled)): if not enabled or ratio >= 1.0: @@ -802,9 +808,16 @@ class PipelineLike: logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") control_net.set_cond_image(None) each_control_net_enabled[j] = False + if self.control_nets and self.is_sdxl: + for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)): + if not enabled or ratio >= 1.0: + continue + if ratio < i / len(timesteps): + logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + each_control_net_enabled[j] = False # predict the noise residual - if self.control_nets and self.control_net_enabled: + if self.control_nets and self.control_net_enabled and not self.is_sdxl: if regional_network: num_sub_and_neg_prompts = len(text_embeddings) // batch_size text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt @@ -823,6 +836,31 @@ class PipelineLike: text_embeddings, text_emb_last, ).sample + elif self.control_nets: + input_resi_add_list = [] + mid_add_list = [] + for (control_net, _), enbld in zip(self.control_nets, each_control_net_enabled): + if not enbld: + continue + input_resi_add, mid_add = control_net( + latent_model_input, t, text_embeddings, vector_embeddings, clip_guide_images + ) + input_resi_add_list.append(input_resi_add) + mid_add_list.append(mid_add) + if len(input_resi_add_list) == 0: + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) + else: + if len(input_resi_add_list) > 1: + # get mean of input_resi_add_list and mid_add_list + input_resi_add_mean = [] + for i in range(len(input_resi_add_list[0])): + input_resi_add_mean.append( + torch.mean(torch.stack([input_resi_add_list[j][i] for j in range(len(input_resi_add_list))], dim=0)) + ) + input_resi_add = input_resi_add_mean + mid_add = torch.mean(torch.stack(mid_add_list), dim=0) + + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings, input_resi_add, mid_add) elif self.is_sdxl: noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) else: @@ -1827,16 +1865,37 @@ def main(args): upscaler.to(dtype).to(device) # ControlNetの処理 - control_nets: List[ControlNetInfo] = [] + control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = [] if args.control_net_models: - for i, model in enumerate(args.control_net_models): - prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] - weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] - ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + if not is_sdxl: + for i, model in enumerate(args.control_net_models): + prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] - ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) - prep = original_control_net.load_preprocess(prep_type) - control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) + prep = original_control_net.load_preprocess(prep_type) + control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + else: + for i, model_file in enumerate(args.control_net_models): + multiplier = ( + 1.0 + if not args.control_net_multipliers or len(args.control_net_multipliers) <= i + else args.control_net_multipliers[i] + ) + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + logger.info(f"loading SDXL ControlNet: {model_file}") + from safetensors.torch import load_file + + state_dict = load_file(model_file) + + logger.info(f"Initalizing SDXL ControlNet with multiplier: {multiplier}") + with init_empty_weights(): + control_net = SdxlControlNet(multiplier=multiplier) + control_net.load_state_dict(state_dict) + control_net.to(dtype).to(device) + control_nets.append((control_net, ratio)) control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] if args.control_net_lllite_models: diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 4fad78a1..0466c1fa 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -8,7 +8,7 @@ from typing import List from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel from library import model_util from library import sdxl_original_unet -from .utils import setup_logging +from library.utils import setup_logging setup_logging() import logging diff --git a/library/sdxl_original_control_net.py b/library/sdxl_original_control_net.py new file mode 100644 index 00000000..3af45f4d --- /dev/null +++ b/library/sdxl_original_control_net.py @@ -0,0 +1,272 @@ +# some parts are modified from Diffusers library (Apache License 2.0) + +import math +from types import SimpleNamespace +from typing import Any, Optional +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from einops import rearrange +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import sdxl_original_unet +from library.sdxl_model_util import convert_sdxl_unet_state_dict_to_diffusers, convert_diffusers_unet_state_dict_to_sdxl + + +class ControlNetConditioningEmbedding(nn.Module): + def __init__(self): + super().__init__() + + dims = [16, 32, 96, 256] + + self.conv_in = nn.Conv2d(3, dims[0], kernel_size=3, padding=1) + self.blocks = nn.ModuleList([]) + + for i in range(len(dims) - 1): + channel_in = dims[i] + channel_out = dims[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = nn.Conv2d(dims[-1], 320, kernel_size=3, padding=1) + nn.init.zeros_(self.conv_out.weight) # zero module weight + nn.init.zeros_(self.conv_out.bias) # zero module bias + + def forward(self, x): + x = self.conv_in(x) + x = F.silu(x) + for block in self.blocks: + x = block(x) + x = F.silu(x) + x = self.conv_out(x) + return x + + +class SdxlControlNet(sdxl_original_unet.SdxlUNet2DConditionModel): + def __init__(self, multiplier: Optional[float] = None, **kwargs): + super().__init__(**kwargs) + self.multiplier = multiplier + + # remove unet layers + self.output_blocks = nn.ModuleList([]) + del self.out + + self.controlnet_cond_embedding = ControlNetConditioningEmbedding() + + dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280] + self.controlnet_down_blocks = nn.ModuleList([]) + for dim in dims: + self.controlnet_down_blocks.append(nn.Conv2d(dim, dim, kernel_size=1)) + nn.init.zeros_(self.controlnet_down_blocks[-1].weight) # zero module weight + nn.init.zeros_(self.controlnet_down_blocks[-1].bias) # zero module bias + + self.controlnet_mid_block = nn.Conv2d(1280, 1280, kernel_size=1) + nn.init.zeros_(self.controlnet_mid_block.weight) # zero module weight + nn.init.zeros_(self.controlnet_mid_block.bias) # zero module bias + + def init_from_unet(self, unet: sdxl_original_unet.SdxlUNet2DConditionModel): + unet_sd = unet.state_dict() + unet_sd = {k: v for k, v in unet_sd.items() if not k.startswith("out")} + sd = super().state_dict() + sd.update(unet_sd) + info = super().load_state_dict(sd, strict=True, assign=True) + return info + + def load_state_dict(self, state_dict: dict, strict: bool = True, assign: bool = True) -> Any: + # convert state_dict to SAI format + unet_sd = {} + for k in list(state_dict.keys()): + if not k.startswith("controlnet_"): + unet_sd[k] = state_dict.pop(k) + unet_sd = convert_diffusers_unet_state_dict_to_sdxl(unet_sd) + state_dict.update(unet_sd) + super().load_state_dict(state_dict, strict=strict, assign=assign) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + # convert state_dict to Diffusers format + state_dict = super().state_dict(destination, prefix, keep_vars) + control_net_sd = {} + for k in list(state_dict.keys()): + if k.startswith("controlnet_"): + control_net_sd[k] = state_dict.pop(k) + state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict) + state_dict.update(control_net_sd) + return state_dict + + def forward( + self, + x: torch.Tensor, + timesteps: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + cond_image: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) + t_emb = t_emb.to(x.dtype) + emb = self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + emb = emb + self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + if isinstance(layer, sdxl_original_unet.ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, sdxl_original_unet.Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + h = x + multiplier = self.multiplier if self.multiplier is not None else 1.0 + hs = [] + for i, module in enumerate(self.input_blocks): + h = call_module(module, h, emb, context) + if i == 0: + h = self.controlnet_cond_embedding(cond_image) + h + hs.append(self.controlnet_down_blocks[i](h) * multiplier) + + h = call_module(self.middle_block, h, emb, context) + h = self.controlnet_mid_block(h) * multiplier + + return hs, h + + +class SdxlControlledUNet(sdxl_original_unet.SdxlUNet2DConditionModel): + """ + This class is for training purpose only. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs): + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + hs = [] + t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) + t_emb = t_emb.to(x.dtype) + emb = self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + emb = emb + self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + if isinstance(layer, sdxl_original_unet.ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, sdxl_original_unet.Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + h = x + for module in self.input_blocks: + h = call_module(module, h, emb, context) + hs.append(h) + + h = call_module(self.middle_block, h, emb, context) + h = h + mid_add + + for module in self.output_blocks: + resi = hs.pop() + input_resi_add.pop() + h = torch.cat([h, resi], dim=1) + h = call_module(module, h, emb, context) + + h = h.type(x.dtype) + h = call_module(self.out, h, emb, context) + + return h + + +if __name__ == "__main__": + import time + + logger.info("create unet") + unet = SdxlControlledUNet() + unet.to("cuda", torch.bfloat16) + unet.set_use_sdpa(True) + unet.set_gradient_checkpointing(True) + unet.train() + + logger.info("create control_net") + control_net = SdxlControlNet() + control_net.to("cuda") + control_net.set_use_sdpa(True) + control_net.set_gradient_checkpointing(True) + control_net.train() + + logger.info("Initialize control_net from unet") + control_net.init_from_unet(unet) + + unet.requires_grad_(False) + control_net.requires_grad_(True) + + # 使用メモリ量確認用の疑似学習ループ + logger.info("preparing optimizer") + + # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working + + import bitsandbytes + + optimizer = bitsandbytes.adam.Adam8bit(control_net.parameters(), lr=1e-3) # not working + # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 + # optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 + + # import transformers + # optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2 + + scaler = torch.cuda.amp.GradScaler(enabled=True) + + logger.info("start training") + steps = 10 + batch_size = 1 + + for step in range(steps): + logger.info(f"step {step}") + if step == 1: + time_start = time.perf_counter() + + x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024 + t = torch.randint(low=0, high=1000, size=(batch_size,), device="cuda") + txt = torch.randn(batch_size, 77, 2048).cuda() + vector = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda() + cond_img = torch.rand(batch_size, 3, 1024, 1024).cuda() + + with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16): + input_resi_add, mid_add = control_net(x, t, txt, vector, cond_img) + output = unet(x, t, txt, vector, input_resi_add, mid_add) + target = torch.randn_like(output) + loss = torch.nn.functional.mse_loss(output, target) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad(set_to_none=True) + + time_end = time.perf_counter() + logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") + + logger.info("finish training") + sd = control_net.state_dict() + + from safetensors.torch import save_file + + save_file(sd, r"E:\Work\SD\Tmp\sdxl\ctrl\control_net.safetensors") diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 17c345a8..0aa07d0d 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -30,7 +30,7 @@ import torch.utils.checkpoint from torch import nn from torch.nn import functional as F from einops import rearrange -from .utils import setup_logging +from library.utils import setup_logging setup_logging() import logging @@ -1156,9 +1156,9 @@ class InferSdxlUNet2DConditionModel: self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 self.ds_ratio = ds_ratio - def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs): r""" - current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink. + current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink and ControlNet. """ _self = self.delegate @@ -1209,6 +1209,8 @@ class InferSdxlUNet2DConditionModel: hs.append(h) h = call_module(_self.middle_block, h, emb, context) + if mid_add is not None: + h = h + mid_add for module in _self.output_blocks: # Deep Shrink @@ -1217,7 +1219,11 @@ class InferSdxlUNet2DConditionModel: # print("upsample", h.shape, hs[-1].shape) h = resize_like(h, hs[-1]) - h = torch.cat([h, hs.pop()], dim=1) + resi = hs.pop() + if input_resi_add is not None: + resi = resi + input_resi_add.pop() + + h = torch.cat([h, resi], dim=1) h = call_module(module, h, emb, context) # Deep Shrink: in case of depth 0 diff --git a/sdxl_train_controlnet.py b/sdxl_train_control_net.py similarity index 69% rename from sdxl_train_controlnet.py rename to sdxl_train_control_net.py index 00026d2c..74dcff2a 100644 --- a/sdxl_train_controlnet.py +++ b/sdxl_train_control_net.py @@ -14,6 +14,7 @@ init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed +from accelerate import init_empty_weights from diffusers import DDPMScheduler, ControlNetModel from diffusers.utils.torch_utils import is_compiled_module from safetensors.torch import load_file @@ -23,6 +24,9 @@ from library import ( sdxl_model_util, sdxl_original_unet, sdxl_train_util, + strategy_base, + strategy_sd, + strategy_sdxl, ) import library.model_util as model_util @@ -41,6 +45,7 @@ from library.custom_train_functions import ( scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, ) +from library.sdxl_original_control_net import SdxlControlNet, SdxlControlledUNet from library.utils import setup_logging, add_logging_arguments setup_logging() @@ -58,10 +63,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche } if args.optimizer_type.lower().startswith("DAdapt".lower()): - logs["lr/d*lr"] = ( - lr_scheduler.optimizers[-1].param_groups[0]["d"] - * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - ) + logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] return logs @@ -79,7 +81,14 @@ def train(args): args.seed = random.randint(0, 2**32) set_seed(args.seed) - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) @@ -106,17 +115,18 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) - ds_for_collator = (train_dataset_group if args.max_data_loader_n_workers == 0 else None) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) train_dataset_group.verify_bucket_reso_steps(32) if args.debug_dataset: + train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: @@ -162,86 +172,99 @@ def train(args): unet, logit_scale, ckpt_info, - ) = sdxl_train_util.load_target_model( - args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype - ) + ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) - # convert U-Net - with torch.no_grad(): - du_unet_sd = sdxl_model_util.convert_sdxl_unet_state_dict_to_diffusers(unet.state_dict()) - unet.to("cpu") - clean_memory_on_device(accelerator.device) - del unet - unet = sdxl_model_util.UNet2DConditionModel(**sdxl_model_util.DIFFUSERS_SDXL_UNET_CONFIG) - unet.load_state_dict(du_unet_sd) + unet.to(accelerator.device) # reduce main memory usage - controlnet = ControlNetModel.from_unet(unet) + # convert U-Net to Controlled U-Net + logger.info("convert U-Net to Controlled U-Net") + unet_sd = unet.state_dict() + with init_empty_weights(): + unet = SdxlControlledUNet() + unet.load_state_dict(unet_sd, strict=True, assign=True) + del unet_sd - if args.controlnet_model_name_or_path: - filename = args.controlnet_model_name_or_path - if os.path.isfile(filename): - if os.path.splitext(filename)[1] == ".safetensors": - state_dict = load_file(filename) - else: - state_dict = torch.load(filename) - state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict) - controlnet.load_state_dict(state_dict) - elif os.path.isdir(filename): - controlnet = ControlNetModel.from_pretrained(filename) + # make control net + logger.info("make ControlNet") + if args.controlnet_model_path: + with init_empty_weights(): + control_net = SdxlControlNet() + + logger.info(f"load ControlNet from {args.controlnet_model_path}") + filename = args.controlnet_model_path + if os.path.splitext(filename)[1] == ".safetensors": + state_dict = load_file(filename) + else: + state_dict = torch.load(filename) + info = control_net.load_state_dict(state_dict, strict=True, assign=True) + logger.info(f"ControlNet loaded from {filename}: {info}") + else: + control_net = SdxlControlNet() + + logger.info("initialize ControlNet from U-Net") + info = control_net.init_from_unet(unet) + logger.info(f"ControlNet initialized from U-Net: {info}") # 学習を準備する if cache_latents: vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) + + train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # TextEncoderの出力をキャッシュする if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad - with torch.no_grad(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) + text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) + + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + accelerator.wait_for_everyone() # モデルに xformers とか memory efficient attention を組み込む # train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) if args.xformers: - unet.enable_xformers_memory_efficient_attention() - controlnet.enable_xformers_memory_efficient_attention() + unet.set_use_memory_efficient_attention(True, False) + control_net.set_use_memory_efficient_attention(True, False) + elif args.sdpa: + unet.set_use_sdpa(True) + control_net.set_use_sdpa(True) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - controlnet.enable_gradient_checkpointing() + control_net.enable_gradient_checkpointing() # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - trainable_params = list(filter(lambda p: p.requires_grad, controlnet.parameters())) + trainable_params = list(control_net.parameters()) + # for p in trainable_params: + # p.requires_grad = True logger.info(f"trainable params count: {len(trainable_params)}") - logger.info( - f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}" - ) + logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -257,7 +280,7 @@ def train(args): # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader)/ accelerator.num_processes/ args.gradient_accumulation_steps + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) accelerator.print( f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" @@ -267,9 +290,7 @@ def train(args): train_dataset_group.set_max_train_steps(args.max_train_steps) # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix( - args, optimizer, accelerator.num_processes - ) + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする if args.full_fp16: @@ -277,19 +298,17 @@ def train(args): args.mixed_precision == "fp16" ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" accelerator.print("enable full fp16 training.") - controlnet.to(weight_dtype) - unet.to(weight_dtype) + control_net.to(weight_dtype) elif args.full_bf16: assert ( args.mixed_precision == "bf16" ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" accelerator.print("enable full bf16 training.") - controlnet.to(weight_dtype) - unet.to(weight_dtype) + control_net.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい - controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - controlnet, optimizer, train_dataloader, lr_scheduler + control_net, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + control_net, optimizer, train_dataloader, lr_scheduler ) if args.fused_backward_pass: @@ -314,10 +333,8 @@ def train(args): text_encoder2.requires_grad_(False) unet.to(accelerator.device, dtype=weight_dtype) - # transform DDP after prepare - controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet - - controlnet.train() + unet.eval() + control_net.train() # TextEncoderの出力をキャッシュするときにはCPUへ移動する if args.cache_text_encoder_outputs: @@ -362,26 +379,15 @@ def train(args): accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - progress_bar = tqdm( - range(args.max_train_steps), - smoothing=0, - disable=not accelerator.is_local_main_process, - desc="steps", - ) + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 noise_scheduler = DDPMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - clip_sample=False, + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False ) prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) if args.zero_terminal_snr: - custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr( - noise_scheduler - ) + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) if accelerator.is_main_process: init_kwargs = {} @@ -390,11 +396,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - ( - "controlnet_train" - if args.log_tracker_name is None - else args.log_tracker_name - ), + ("sdxl_control_net_train" if args.log_tracker_name is None else args.log_tracker_name), config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs, ) @@ -409,10 +411,8 @@ def train(args): accelerator.print(f"\nsaving checkpoint: {ckpt_file}") sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) - sai_metadata["modelspec.architecture"] = ( - sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet" - ) - state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) + sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet" + state_dict = model.state_dict() if save_dtype is not None: for key in list(state_dict.keys()): @@ -436,19 +436,19 @@ def train(args): accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) - # For --sample_at_first - sdxl_train_util.sample_images( - accelerator, - args, - 0, - global_step, - accelerator.device, - vae, - [tokenizer1, tokenizer2], - [text_encoder1, text_encoder2], - unet, - controlnet=controlnet, - ) + # # For --sample_at_first + # sdxl_train_util.sample_images( + # accelerator, + # args, + # 0, + # global_step, + # accelerator.device, + # vae, + # [tokenizer1, tokenizer2], + # [text_encoder1, text_encoder2], + # unet, + # controlnet=control_net, + # ) # training loop for epoch in range(num_train_epochs): @@ -457,121 +457,63 @@ def train(args): for step, batch in enumerate(train_dataloader): current_step.value = global_step - with accelerator.accumulate(controlnet): + with accelerator.accumulate(control_net): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = ( - batch["latents"] - .to(accelerator.device) - .to(dtype=weight_dtype) - ) + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: # latentに変換 - latents = ( - vae.encode(batch["images"].to(dtype=vae_dtype)) - .latent_dist.sample() - .to(dtype=weight_dtype) - ) + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): - accelerator.print( - "NaN found in latents, replacing with zeros" - ) + accelerator.print("NaN found in latents, replacing with zeros") latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - if ( - "text_encoder_outputs1_list" not in batch - or batch["text_encoder_outputs1_list"] is None - ): - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + # Text Encoder outputs are cached + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list + encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype) + pool2 = pool2.to(accelerator.device, dtype=weight_dtype) + else: + input_ids1, input_ids2 = batch["input_ids_list"] with torch.no_grad(): - # Get the text embedding for conditioning input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = ( - train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, - ) + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] ) - else: - encoder_hidden_states1 = ( - batch["text_encoder_outputs1_list"] - .to(accelerator.device) - .to(weight_dtype) - ) - encoder_hidden_states2 = ( - batch["text_encoder_outputs2_list"] - .to(accelerator.device) - .to(weight_dtype) - ) - pool2 = ( - batch["text_encoder_pool2_list"] - .to(accelerator.device) - .to(weight_dtype) - ) + if args.full_fp16: + encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) + pool2 = pool2.to(weight_dtype) # get size embeddings orig_size = batch["original_sizes_hw"] crop_size = batch["crop_top_lefts"] target_size = batch["target_sizes_hw"] - # embs = sdxl_train_util.get_size_embeddings( - # orig_size, crop_size, target_size, accelerator.device - # ).to(weight_dtype) + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) - embs = torch.cat([orig_size, crop_size, target_size]).to(accelerator.device).to(weight_dtype) #B,6 # concat embeddings - #vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) - vector_embedding_dict = { - "text_embeds": pool2, - "time_ids": embs - } - text_embedding = torch.cat( - [encoder_hidden_states1, encoder_hidden_states2], dim=2 - ).to(weight_dtype) + vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = ( - train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents ) controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) - with accelerator.autocast(): - down_block_res_samples, mid_block_res_sample = controlnet( - noisy_latents, - timesteps, - encoder_hidden_states=text_embedding, - added_cond_kwargs=vector_embedding_dict, - controlnet_cond=controlnet_image, - return_dict=False, + input_resi_add, mid_add = control_net( + noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image ) - - # Predict the noise residual - noise_pred = unet( - noisy_latents, - timesteps, - encoder_hidden_states=text_embedding, - added_cond_kwargs=vector_embedding_dict, - down_block_additional_residuals=[ - sample.to(dtype=weight_dtype) for sample in down_block_res_samples - ], - mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), - return_dict=False, - )[0] + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, input_resi_add, mid_add) if args.v_parameterization: # v-parameterization training @@ -580,7 +522,7 @@ def train(args): target = noise loss = train_util.conditional_loss( - noise_pred.float(),target.float(),reduction="none",loss_type=args.loss_type,huber_c=huber_c, + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) loss = loss.mean([1, 2, 3]) @@ -588,7 +530,7 @@ def train(args): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss,timesteps,noise_scheduler,args.min_snr_gamma,args.v_parameterization) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: @@ -601,7 +543,7 @@ def train(args): accelerator.backward(loss) if not args.fused_backward_pass: if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = controlnet.parameters() + params_to_clip = control_net.parameters() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() @@ -616,25 +558,25 @@ def train(args): progress_bar.update(1) global_step += 1 - sdxl_train_util.sample_images( - accelerator, - args, - None, - global_step, - accelerator.device, - vae, - [tokenizer1, tokenizer2], - [text_encoder1, text_encoder2], - unet, - controlnet=controlnet, - ) + # sdxl_train_util.sample_images( + # accelerator, + # args, + # None, + # global_step, + # accelerator.device, + # vae, + # [tokenizer1, tokenizer2], + # [text_encoder1, text_encoder2], + # unet, + # controlnet=control_net, + # ) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model(ckpt_name,unwrap_model(controlnet)) + save_model(ckpt_name, unwrap_model(control_net)) if args.save_state: train_util.save_and_remove_state_stepwise(args, accelerator, global_step) @@ -650,14 +592,14 @@ def train(args): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) @@ -668,7 +610,7 @@ def train(args): saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs if is_main_process and saving: ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name,unwrap_model(controlnet)) + save_model(ckpt_name, unwrap_model(control_net)) remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) if remove_epoch_no is not None: @@ -688,13 +630,13 @@ def train(args): [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet, - controlnet=controlnet, + controlnet=control_net, ) # end of epoch if is_main_process: - controlnet = unwrap_model(controlnet) + control_net = unwrap_model(control_net) accelerator.end_training() @@ -703,9 +645,7 @@ def train(args): if is_main_process: ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) - save_model( - ckpt_name, controlnet, force_sync_upload=True - ) + save_model(ckpt_name, control_net, force_sync_upload=True) logger.info("model saved.") @@ -717,26 +657,38 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) - train_util.add_masked_loss_arguments(parser) + # train_util.add_masked_loss_arguments(parser) deepspeed_utils.add_deepspeed_arguments(parser) - train_util.add_sd_saving_arguments(parser) + # train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) sdxl_train_util.add_sdxl_training_arguments(parser) parser.add_argument( - "--controlnet_model_name_or_path", + "--controlnet_model_path", type=str, default=None, help="controlnet model name or path / controlnetのモデル名またはパス", ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) parser.add_argument( "--no_half_vae", action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) - return parser From 0243c65877a7700ffab1e782690f26080a0deadc Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 23:09:56 +0900 Subject: [PATCH 227/356] fix typo --- gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gen_img.py b/gen_img.py index 70b3c81f..421d5c0b 100644 --- a/gen_img.py +++ b/gen_img.py @@ -1890,7 +1890,7 @@ def main(args): state_dict = load_file(model_file) - logger.info(f"Initalizing SDXL ControlNet with multiplier: {multiplier}") + logger.info(f"Initializing SDXL ControlNet with multiplier: {multiplier}") with init_empty_weights(): control_net = SdxlControlNet(multiplier=multiplier) control_net.load_state_dict(state_dict) From 012e7e63a5b1acdf69c72eee4cb330a5a6defc41 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 23:18:16 +0900 Subject: [PATCH 228/356] fix to work linear/cosine scheduler closes #1651 ref #1393 --- library/train_util.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 422dceca..27910dc9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4496,6 +4496,15 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): **lr_scheduler_kwargs, ) + # these schedulers do not require `num_decay_steps` + if name == SchedulerType.LINEAR or name == SchedulerType.COSINE: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **lr_scheduler_kwargs, + ) + # All other schedulers require `num_decay_steps` if num_decay_steps is None: raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") From 793999d116638548fc16579b712f44456ee3034e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 30 Sep 2024 23:39:32 +0900 Subject: [PATCH 229/356] sample generation in SDXL ControlNet training --- library/sdxl_lpw_stable_diffusion.py | 166 +++++++---------------- library/strategy_base.py | 192 ++++++++++++++++++++++++++- library/strategy_sdxl.py | 39 +++++- library/train_util.py | 35 +++-- sdxl_train_control_net.py | 55 ++++---- 5 files changed, 322 insertions(+), 165 deletions(-) diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index 03b18256..9196eb0f 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -13,12 +13,20 @@ from tqdm import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from diffusers import SchedulerMixin, StableDiffusionPipeline -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker +from diffusers.models import AutoencoderKL +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.utils import logging from PIL import Image -from library import sdxl_model_util, sdxl_train_util, train_util +from library import ( + sdxl_model_util, + sdxl_train_util, + strategy_base, + strategy_sdxl, + train_util, + sdxl_original_unet, + sdxl_original_control_net, +) try: @@ -537,7 +545,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: vae: AutoencoderKL, text_encoder: List[CLIPTextModel], tokenizer: List[CLIPTokenizer], - unet: UNet2DConditionModel, + unet: Union[sdxl_original_unet.SdxlUNet2DConditionModel, sdxl_original_control_net.SdxlControlledUNet], scheduler: SchedulerMixin, # clip_skip: int, safety_checker: StableDiffusionSafetyChecker, @@ -594,74 +602,6 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: return torch.device(module._hf_hook.execution_device) return self.device - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, - is_sdxl_text_encoder2, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `list(int)`): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - is_sdxl_text_encoder2=is_sdxl_text_encoder2, - ) - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) # ?? - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - if text_pool is not None: - text_pool = text_pool.repeat(1, num_images_per_prompt) - text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1) - - if do_classifier_free_guidance: - bs_embed, seq_len, _ = uncond_embeddings.shape - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - if uncond_pool is not None: - uncond_pool = uncond_pool.repeat(1, num_images_per_prompt) - uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1) - - return text_embeddings, text_pool, uncond_embeddings, uncond_pool - - return text_embeddings, text_pool, None, None - def check_inputs(self, prompt, height, width, strength, callback_steps): if not isinstance(prompt, str) and not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") @@ -792,7 +732,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", return_dict: bool = True, - controlnet=None, + controlnet: sdxl_original_control_net.SdxlControlNet = None, controlnet_image=None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None, @@ -896,32 +836,24 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - # 実装を簡単にするためにtokenzer/text encoderを切り替えて二回呼び出す - # To simplify the implementation, switch the tokenzer/text encoder and call it twice - text_embeddings_list = [] - text_pool = None - uncond_embeddings_list = [] - uncond_pool = None - for i in range(len(self.tokenizers)): - self.tokenizer = self.tokenizers[i] - self.text_encoder = self.text_encoders[i] + tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy: strategy_sdxl.SdxlTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, - is_sdxl_text_encoder2=i == 1, + text_input_ids, text_weights = tokenize_strategy.tokenize_with_weights(prompt) + hidden_states_1, hidden_states_2, text_pool = encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, self.text_encoders, text_input_ids, text_weights + ) + text_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1) + + if do_classifier_free_guidance: + input_ids, weights = tokenize_strategy.tokenize_with_weights(negative_prompt or "") + hidden_states_1, hidden_states_2, uncond_pool = encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, self.text_encoders, input_ids, weights ) - text_embeddings_list.append(text_embeddings) - uncond_embeddings_list.append(uncond_embeddings) - - if tp1 is not None: - text_pool = tp1 - if up1 is not None: - uncond_pool = up1 + uncond_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1) + else: + uncond_embeddings = None + uncond_pool = None unet_dtype = self.unet.dtype dtype = unet_dtype @@ -970,23 +902,23 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # create size embs and concat embeddings for SDXL - orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype) + orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(device, dtype) crop_size = torch.zeros_like(orig_size) target_size = orig_size - embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype) + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(device, dtype) # make conditionings + text_pool = text_pool.to(device, dtype) if do_classifier_free_guidance: - text_embeddings = torch.cat(text_embeddings_list, dim=2) - uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2) - text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype) + text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(device, dtype) - cond_vector = torch.cat([text_pool, embs], dim=1) - uncond_vector = torch.cat([uncond_pool, embs], dim=1) - vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype) + uncond_pool = uncond_pool.to(device, dtype) + cond_vector = torch.cat([text_pool, embs], dim=1).to(dtype) + uncond_vector = torch.cat([uncond_pool, embs], dim=1).to(dtype) + vector_embedding = torch.cat([uncond_vector, cond_vector]) else: - text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype) - vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype) + text_embedding = text_embeddings.to(device, dtype) + vector_embedding = torch.cat([text_pool, embs], dim=1) # 8. Denoising loop for i, t in enumerate(self.progress_bar(timesteps)): @@ -994,22 +926,14 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - unet_additional_args = {} - if controlnet is not None: - down_block_res_samples, mid_block_res_sample = controlnet( - latent_model_input, - t, - encoder_hidden_states=text_embeddings, - controlnet_cond=controlnet_image, - conditioning_scale=1.0, - guess_mode=False, - return_dict=False, - ) - unet_additional_args["down_block_additional_residuals"] = down_block_res_samples - unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample + # FIXME SD1 ControlNet is not working # predict the noise residual - noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding) + if controlnet is not None: + input_resi_add, mid_add = controlnet(latent_model_input, t, text_embedding, vector_embedding, controlnet_image) + noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding, input_resi_add, mid_add) + else: + noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding) noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training # perform guidance diff --git a/library/strategy_base.py b/library/strategy_base.py index e7d3a97e..10820afa 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -1,6 +1,7 @@ # base class for platform strategies. this file defines the interface for strategies import os +import re from typing import Any, List, Optional, Tuple, Union import numpy as np @@ -22,6 +23,24 @@ logger = logging.getLogger(__name__) class TokenizeStrategy: _strategy = None # strategy instance: actual strategy class + _re_attention = re.compile( + r"""\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, + ) + @classmethod def set_strategy(cls, strategy): if cls._strategy is not None: @@ -54,7 +73,151 @@ class TokenizeStrategy: def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: raise NotImplementedError - def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None) -> torch.Tensor: + def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + raise NotImplementedError + + def _get_weighted_input_ids( + self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + max_length includes starting and ending tokens. + """ + + def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in TokenizeStrategy._re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + def get_prompts_with_weights(text: str, max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. max_length does not include starting and ending token. + + No padding, starting or ending token is included. + """ + truncated = False + + texts_and_weights = parse_prompt_attention(text) + tokens = [] + weights = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + tokens += token + # copy the weight by length of token + weights += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(tokens) > max_length: + truncated = True + break + # truncate + if len(tokens) > max_length: + truncated = True + tokens = tokens[:max_length] + weights = weights[:max_length] + if truncated: + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + tokens = [bos] + tokens + [eos] + [pad] * (max_length - 2 - len(tokens)) + weights = [1.0] + weights + [1.0] * (max_length - 1 - len(weights)) + return tokens, weights + + if max_length is None: + max_length = tokenizer.model_max_length + + tokens, weights = get_prompts_with_weights(text, max_length - 2) + tokens, weights = pad_tokens_and_weights( + tokens, weights, max_length, tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id + ) + return torch.tensor(tokens).unsqueeze(0), torch.tensor(weights).unsqueeze(0) + + def _get_input_ids( + self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None, weighted: bool = False + ) -> torch.Tensor: """ for SD1.5/2.0/SDXL TODO support batch input @@ -62,7 +225,10 @@ class TokenizeStrategy: if max_length is None: max_length = tokenizer.model_max_length - 2 - input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids + if weighted: + input_ids, weights = self._get_weighted_input_ids(tokenizer, text, max_length) + else: + input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids if max_length > tokenizer.model_max_length: input_ids = input_ids.squeeze(0) @@ -101,6 +267,17 @@ class TokenizeStrategy: iids_list.append(ids_chunk) input_ids = torch.stack(iids_list) # 3,77 + + if weighted: + weights = weights.squeeze(0) + new_weights = torch.ones(input_ids.shape) + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): + b = i // (tokenizer.model_max_length - 2) + new_weights[b, 1 : 1 + tokenizer.model_max_length - 2] = weights[i : i + tokenizer.model_max_length - 2] + weights = new_weights + + if weighted: + return input_ids, weights return input_ids @@ -126,6 +303,17 @@ class TextEncodingStrategy: :return: list of output embeddings for each architecture """ raise NotImplementedError + + def encode_tokens_with_weights( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Encode tokens into embeddings and outputs. + :param tokens: list of token tensors for each TextModel + :param weights: list of weight tensors for each TextModel + :return: list of output embeddings for each architecture + """ + raise NotImplementedError class TextEncoderOutputsCachingStrategy: diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py index 3eb0ab6f..b48e6d55 100644 --- a/library/strategy_sdxl.py +++ b/library/strategy_sdxl.py @@ -37,6 +37,22 @@ class SdxlTokenizeStrategy(TokenizeStrategy): torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0), ) + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: + text = [text] if isinstance(text, str) else text + tokens1_list, tokens2_list = [], [] + weights1_list, weights2_list = [], [] + for t in text: + tokens1, weights1 = self._get_weighted_input_ids(self.tokenizer1, t, self.max_length) + tokens2, weights2 = self._get_weighted_input_ids(self.tokenizer2, t, self.max_length) + tokens1_list.append(tokens1) + tokens2_list.append(tokens2) + weights1_list.append(weights1) + weights2_list.append(weights2) + return (torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)), ( + torch.stack(weights1_list, dim=0), + torch.stack(weights2_list, dim=0), + ) + class SdxlTextEncodingStrategy(TextEncodingStrategy): def __init__(self) -> None: @@ -98,7 +114,10 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy): ): # input_ids: b,n,77 -> b*n, 77 b_size = input_ids1.size()[0] - max_token_length = input_ids1.size()[1] * input_ids1.size()[2] + if input_ids1.size()[1] == 1: + max_token_length = None + else: + max_token_length = input_ids1.size()[1] * input_ids1.size()[2] input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77 input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77 input_ids1 = input_ids1.to(text_encoder1.device) @@ -172,6 +191,24 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy): ) return [hidden_states1, hidden_states2, pool2] + def encode_tokens_with_weights( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor] + ) -> List[torch.Tensor]: + hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens) + + # apply weights + if weights[0].shape[1] == 1: # no max_token_length + # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) + hidden_states1 = hidden_states1 * weights[0].squeeze(1).unsqueeze(2) + hidden_states2 = hidden_states2 * weights[1].squeeze(1).unsqueeze(2) + else: + # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) + for weight, hidden_states in zip(weights, [hidden_states1, hidden_states2]): + for i in range(weight.shape[1]): + hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[:, i, 1:-1] + + return [hidden_states1, hidden_states2, pool2] + class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz" diff --git a/library/train_util.py b/library/train_util.py index 293fc05a..b559616f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -74,6 +74,7 @@ import imagesize import cv2 import safetensors.torch from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline +from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec @@ -3581,7 +3582,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: # available backends: # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 # https://pytorch.org/docs/stable/torch.compiler.html - choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt", "tensort", "ipex", "tvm"], + choices=[ + "eager", + "aot_eager", + "inductor", + "aot_ts_nvfuser", + "nvprims_nvfuser", + "cudagraphs", + "ofi", + "fx2trt", + "onnxrt", + "tensort", + "ipex", + "tvm", + ], help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)", ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") @@ -5850,8 +5864,8 @@ def sample_images_common( pipe_class, accelerator: Accelerator, args: argparse.Namespace, - epoch, - steps, + epoch: int, + steps: int, device, vae, tokenizer, @@ -5910,11 +5924,7 @@ def sample_images_common( with open(args.sample_prompts, "r", encoding="utf-8") as f: prompts = json.load(f) - # schedulers: dict = {} cannot find where this is used - default_scheduler = get_my_scheduler( - sample_sampler=args.sample_sampler, - v_parameterization=args.v_parameterization, - ) + default_scheduler = get_my_scheduler(sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization) pipeline = pipe_class( text_encoder=text_encoder, @@ -5975,21 +5985,18 @@ def sample_images_common( # clear pipeline and cache to reduce vram usage del pipeline - # I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here. - # with torch.cuda.device(torch.cuda.current_device()): - # torch.cuda.empty_cache() - clean_memory_on_device(accelerator.device) - torch.set_rng_state(rng_state) if torch.cuda.is_available() and cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + def sample_image_inference( accelerator: Accelerator, args: argparse.Namespace, - pipeline, + pipeline: Union[StableDiffusionLongPromptWeightingPipeline, SdxlStableDiffusionLongPromptWeightingPipeline], save_dir, prompt_dict, epoch, diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 74dcff2a..583a27dc 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -83,6 +83,7 @@ def train(args): tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizer1, tokenizer2 = tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2 # this is used for sampling images # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( @@ -436,19 +437,19 @@ def train(args): accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) - # # For --sample_at_first - # sdxl_train_util.sample_images( - # accelerator, - # args, - # 0, - # global_step, - # accelerator.device, - # vae, - # [tokenizer1, tokenizer2], - # [text_encoder1, text_encoder2], - # unet, - # controlnet=control_net, - # ) + # For --sample_at_first + sdxl_train_util.sample_images( + accelerator, + args, + 0, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], + unet, + controlnet=control_net, + ) # training loop for epoch in range(num_train_epochs): @@ -484,7 +485,7 @@ def train(args): input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( - tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] + tokenize_strategy, [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], [input_ids1, input_ids2] ) if args.full_fp16: encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) @@ -558,18 +559,18 @@ def train(args): progress_bar.update(1) global_step += 1 - # sdxl_train_util.sample_images( - # accelerator, - # args, - # None, - # global_step, - # accelerator.device, - # vae, - # [tokenizer1, tokenizer2], - # [text_encoder1, text_encoder2], - # unet, - # controlnet=control_net, - # ) + sdxl_train_util.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], + unet, + controlnet=control_net, + ) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -628,7 +629,7 @@ def train(args): accelerator.device, vae, [tokenizer1, tokenizer2], - [text_encoder1, text_encoder2], + [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], unet, controlnet=control_net, ) From c2440f9e53239e7e5dee426f611800d3e38a7f0e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 3 Oct 2024 21:32:21 +0900 Subject: [PATCH 230/356] fix cond image normlization, add independent LR for control --- library/sdxl_train_util.py | 3 ++- library/train_util.py | 20 +++++++++++++++++++- sdxl_train_control_net.py | 30 +++++++++++++++++++++++++----- 3 files changed, 46 insertions(+), 7 deletions(-) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index f009b577..aaf77b8d 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -12,7 +12,6 @@ from accelerate import init_empty_weights from tqdm import tqdm from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet -from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline from .utils import setup_logging setup_logging() @@ -378,4 +377,6 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin def sample_images(*args, **kwargs): + from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline + return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) diff --git a/library/train_util.py b/library/train_util.py index b559616f..07c253a0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -31,6 +31,7 @@ import hashlib import subprocess from io import BytesIO import toml +# from concurrent.futures import ThreadPoolExecutor, as_completed from tqdm import tqdm @@ -912,6 +913,23 @@ class BaseDataset(torch.utils.data.Dataset): if info.image_size is None: info.image_size = self.get_image_size(info.absolute_path) + # # run in parallel + # max_workers = min(os.cpu_count(), len(self.image_data)) # TODO consider multi-gpu (processes) + # with ThreadPoolExecutor(max_workers) as executor: + # futures = [] + # for info in tqdm(self.image_data.values(), desc="loading image sizes"): + # if info.image_size is None: + # def get_and_set_image_size(info): + # info.image_size = self.get_image_size(info.absolute_path) + # futures.append(executor.submit(get_and_set_image_size, info)) + # # consume futures to reduce memory usage and prevent Ctrl-C hang + # if len(futures) >= max_workers: + # for future in futures: + # future.result() + # futures = [] + # for future in futures: + # future.result() + if self.enable_bucket: logger.info("make buckets") else: @@ -1826,7 +1844,7 @@ class DreamBoothDataset(BaseDataset): # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] missing_captions = [] - for img_path in img_paths: + for img_path in tqdm(img_paths, desc="read caption"): cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard) if cap_for_img is None and subset.class_tokens is None: logger.warning( diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 583a27dc..b902cda6 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -253,11 +253,20 @@ def train(args): # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - trainable_params = list(control_net.parameters()) - # for p in trainable_params: - # p.requires_grad = True - logger.info(f"trainable params count: {len(trainable_params)}") - logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") + trainable_params = [] + ctrlnet_params = [] + unet_params = [] + for name, param in control_net.named_parameters(): + if name.startswith("controlnet_"): + ctrlnet_params.append(param) + else: + unet_params.append(param) + trainable_params.append({"params": ctrlnet_params, "lr": args.control_net_lr}) + trainable_params.append({"params": unet_params, "lr": args.learning_rate}) + all_params = ctrlnet_params + unet_params + + logger.info(f"trainable params count: {len(all_params)}") + logger.info(f"number of trainable parameters: {sum(p.numel() for p in all_params)}") _, _, optimizer = train_util.get_optimizer(args, trainable_params) @@ -456,6 +465,8 @@ def train(args): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + control_net.train() + for step, batch in enumerate(train_dataloader): current_step.value = global_step with accelerator.accumulate(control_net): @@ -510,6 +521,9 @@ def train(args): controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) + # '-1 to +1' to '0 to 1' + controlnet_image = (controlnet_image + 1) / 2 + with accelerator.autocast(): input_resi_add, mid_add = control_net( noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image @@ -690,6 +704,12 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + parser.add_argument( + "--control_net_lr", + type=float, + default=1e-4, + help="learning rate for controlnet / controlnetの学習率", + ) return parser From ba08a898940c80a6551111fdd77b53c6d3a019ac Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 4 Oct 2024 20:35:16 +0900 Subject: [PATCH 231/356] call optimizer eval/train for sample_at_first, also set train after resuming closes #1667 --- flux_train.py | 2 ++ train_network.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/flux_train.py b/flux_train.py index 022467ea..81c13e4c 100644 --- a/flux_train.py +++ b/flux_train.py @@ -706,7 +706,9 @@ def train(args): accelerator.unwrap_model(flux).prepare_block_swap_before_forward() # For --sample_at_first + optimizer_eval_fn() flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + optimizer_train_fn() if len(accelerator.trackers) > 0: # log empty object to commit the sample images to wandb accelerator.log({}, step=0) diff --git a/train_network.py b/train_network.py index 7b2b76a1..f0d397b9 100644 --- a/train_network.py +++ b/train_network.py @@ -1042,7 +1042,9 @@ class NetworkTrainer: text_encoder = None # For --sample_at_first + optimizer_eval_fn() self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + optimizer_train_fn() if len(accelerator.trackers) > 0: # log empty object to commit the sample images to wandb accelerator.log({}, step=0) From 83e3048cb089bf6726751609da26da751b8383ae Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 6 Oct 2024 21:32:21 +0900 Subject: [PATCH 232/356] load Diffusers format, check schnell/dev --- README.md | 4 + flux_minimal_inference.py | 15 +-- flux_train.py | 15 ++- flux_train_network.py | 17 ++- library/flux_utils.py | 178 +++++++++++++++++++++++++++-- tools/convert_diffusers_to_flux.py | 78 +------------ 6 files changed, 196 insertions(+), 111 deletions(-) diff --git a/README.md b/README.md index 789fe514..c567758a 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Oct 6, 2024: +- In FLUX.1 LoRA training and fine-tuning, the specified weight file (*.safetensors) is automatically determined to be dev or schnell. This allows schnell models to be loaded correctly. Note that LoRA training with schnell models and fine-tuning with schnell models are unverified. +- FLUX.1 LoRA training and fine-tuning can now load weights in Diffusers format in addition to BFL format (a single *.safetensors file). Please specify the parent directory of `transformer` or `diffusion_pytorch_model-00001-of-00003.safetensors` with the full path. However, Diffusers format CLIP/T5XXL is not supported. Saving is supported only in BFL format. + Sep 26, 2024: The implementation of block swap during FLUX.1 fine-tuning has been changed to improve speed about 10% (depends on the environment). A new `--blocks_to_swap` option has been added, and `--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. `--double_blocks_to_swap` and `--single_blocks_to_swap` are working as before, but they will be removed in the future. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 2f1b9a37..7ab224f1 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -419,9 +419,6 @@ if __name__ == "__main__": steps = args.steps guidance_scale = args.guidance - name = "schnell" if "schnell" in args.ckpt_path else "dev" # TODO change this to a more robust way - is_schnell = name == "schnell" - def is_fp8(dt): return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz] @@ -455,12 +452,8 @@ if __name__ == "__main__": # if is_fp8(t5xxl_dtype): # t5xxl = accelerator.prepare(t5xxl) - t5xxl_max_length = 256 if is_schnell else 512 - tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length) - encoding_strategy = strategy_flux.FluxTextEncodingStrategy() - # DiT - model = flux_utils.load_flow_model(name, args.ckpt_path, None, loading_device) + is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device) model.eval() logger.info(f"Casting model to {flux_dtype}") model.to(flux_dtype) # make sure model is dtype @@ -469,8 +462,12 @@ if __name__ == "__main__": # if args.offload: # model = model.to("cpu") + t5xxl_max_length = 256 if is_schnell else 512 + tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length) + encoding_strategy = strategy_flux.FluxTextEncodingStrategy() + # AE - ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device) + ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device) ae.eval() # if is_fp8(ae_dtype): # ae = accelerator.prepare(ae) diff --git a/flux_train.py b/flux_train.py index 81c13e4c..ecc87c0a 100644 --- a/flux_train.py +++ b/flux_train.py @@ -137,6 +137,7 @@ def train(args): train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) if args.debug_dataset: if args.cache_text_encoder_outputs: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( @@ -144,9 +145,8 @@ def train(args): args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False ) ) - name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" t5xxl_max_token_length = ( - args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if name == "schnell" else 512) + args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512) ) strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) @@ -177,12 +177,11 @@ def train(args): weight_dtype, save_dtype = train_util.prepare_dtype(args) # モデルを読み込む - name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # load VAE for caching latents ae = None if cache_latents: - ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae = flux_utils.load_ae( args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) ae.to(accelerator.device, dtype=weight_dtype) ae.requires_grad_(False) ae.eval() @@ -196,7 +195,7 @@ def train(args): # prepare tokenize strategy if args.t5xxl_max_token_length is None: - if name == "schnell": + if is_schnell: t5xxl_max_token_length = 256 else: t5xxl_max_token_length = 512 @@ -258,8 +257,8 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - flux = flux_utils.load_flow_model( - name, args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + _, flux = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors ) if args.gradient_checkpointing: @@ -294,7 +293,7 @@ def train(args): if not cache_latents: # load VAE here if not cached - ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu") ae.requires_grad_(False) ae.eval() ae.to(accelerator.device, dtype=weight_dtype) diff --git a/flux_train_network.py b/flux_train_network.py index 65b121e7..5d14bd28 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -2,7 +2,7 @@ import argparse import copy import math import random -from typing import Any +from typing import Any, Optional import torch from accelerate import Accelerator @@ -24,6 +24,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def __init__(self): super().__init__() self.sample_prompts_te_outputs = None + self.is_schnell: Optional[bool] = None def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) @@ -57,19 +58,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): train_dataset_group.verify_bucket_reso_steps(32) # TODO check this - def get_flux_model_name(self, args): - return "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" - def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models - name = self.get_flux_model_name(args) # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) loading_dtype = None if args.fp8_base else weight_dtype # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future - model = flux_utils.load_flow_model( - name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + self.is_schnell, model = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors ) if args.fp8_base: # check dtype of model @@ -100,7 +97,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): elif t5xxl.dtype == torch.float8_e4m3fn: logger.info("Loaded fp8 T5XXL model") - ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model @@ -142,10 +139,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return flux_lower def get_tokenize_strategy(self, args): - name = self.get_flux_model_name(args) + _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) if args.t5xxl_max_token_length is None: - if name == "schnell": + if is_schnell: t5xxl_max_token_length = 256 else: t5xxl_max_token_length = 512 diff --git a/library/flux_utils.py b/library/flux_utils.py index 7b0a41a8..713814e2 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -1,9 +1,11 @@ import json -from typing import Optional, Union +import os +from typing import List, Optional, Tuple, Union import einops import torch from safetensors.torch import load_file +from safetensors import safe_open from accelerate import init_empty_weights from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config @@ -17,6 +19,8 @@ import logging logger = logging.getLogger(__name__) MODEL_VERSION_FLUX_V1 = "flux1" +MODEL_NAME_DEV = "dev" +MODEL_NAME_SCHNELL = "schnell" # temporary copy from sd3_utils TODO refactor @@ -39,10 +43,35 @@ def load_safetensors( return load_file(path) # prevent device invalid Error +def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool, List[str]]: + # check the state dict: Diffusers or BFL, dev or schnell + logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell") + + if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers + ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors") + if "00001-of-00003" in ckpt_path: + ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)] + else: + ckpt_paths = [ckpt_path] + + keys = [] + for ckpt_path in ckpt_paths: + with safe_open(ckpt_path, framework="pt") as f: + keys.extend(f.keys()) + + is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys + is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) + return is_diffusers, is_schnell, ckpt_paths + + def load_flow_model( - name: str, ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False -) -> flux_models.Flux: - logger.info(f"Building Flux model {name}") + ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False +) -> Tuple[bool, flux_models.Flux]: + is_diffusers, is_schnell, ckpt_paths = check_flux_state_dict_diffusers_schnell(ckpt_path) + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + + # build model + logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint") with torch.device("meta"): model = flux_models.Flux(flux_models.configs[name].params) if dtype is not None: @@ -50,18 +79,28 @@ def load_flow_model( # load_sft doesn't support torch.device logger.info(f"Loading state dict from {ckpt_path}") - sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + sd = {} + for ckpt_path in ckpt_paths: + sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)) + + # convert Diffusers to BFL + if is_diffusers: + logger.info("Converting Diffusers to BFL") + sd = convert_diffusers_sd_to_bfl(sd) + logger.info("Converted Diffusers to BFL") + info = model.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Flux: {info}") - return model + return is_schnell, model def load_ae( - name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False + ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False ) -> flux_models.AutoEncoder: logger.info("Building AutoEncoder") with torch.device("meta"): - ae = flux_models.AutoEncoder(flux_models.configs[name].ae_params).to(dtype) + # dev and schnell have the same AE params + ae = flux_models.AutoEncoder(flux_models.configs[MODEL_NAME_DEV].ae_params).to(dtype) logger.info(f"Loading state dict from {ckpt_path}") sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) @@ -246,3 +285,126 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor: """ x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) return x + + +# region Diffusers + +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + +BFL_TO_DIFFUSERS_MAP = { + "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], + "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], + "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], + "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], + "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], + "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], + "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], + "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], + "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], + "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], + "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], + "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], + "txt_in.weight": ["context_embedder.weight"], + "txt_in.bias": ["context_embedder.bias"], + "img_in.weight": ["x_embedder.weight"], + "img_in.bias": ["x_embedder.bias"], + "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], + "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], + "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], + "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], + "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], + "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], + "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], + "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], + "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], + "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], + "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], + "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], + "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], + "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], + "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], + "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], + "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], + "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], + "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], + "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], + "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], + "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], + "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], + "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], + "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], + "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], + "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], + "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], + "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().linear2.bias": ["proj_out.bias"], + "final_layer.linear.weight": ["proj_out.weight"], + "final_layer.linear.bias": ["proj_out.bias"], + "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], + "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], +} + + +def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]: + # make reverse map from diffusers map + diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) + for b in range(NUM_DOUBLE_BLOCKS): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("double_blocks."): + block_prefix = f"transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for b in range(NUM_SINGLE_BLOCKS): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("single_blocks."): + block_prefix = f"single_transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")): + for i, weight in enumerate(weights): + diffusers_to_bfl_map[weight] = (i, key) + return diffusers_to_bfl_map + + +def convert_diffusers_sd_to_bfl(diffusers_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + diffusers_to_bfl_map = make_diffusers_to_bfl_map() + + # iterate over three safetensors files to reduce memory usage + flux_sd = {} + for diffusers_key, tensor in diffusers_sd.items(): + if diffusers_key in diffusers_to_bfl_map: + index, bfl_key = diffusers_to_bfl_map[diffusers_key] + if bfl_key not in flux_sd: + flux_sd[bfl_key] = [] + flux_sd[bfl_key].append((index, tensor)) + else: + logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}") + raise KeyError(f"Key not found in diffusers_to_bfl_map: {diffusers_key}") + + # concat tensors if multiple tensors are mapped to a single key, sort by index + for key, values in flux_sd.items(): + if len(values) == 1: + flux_sd[key] = values[0][1] + else: + flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])]) + + # special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + if "final_layer.adaLN_modulation.1.weight" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"]) + if "final_layer.adaLN_modulation.1.bias" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"]) + + return flux_sd + + +# endregion diff --git a/tools/convert_diffusers_to_flux.py b/tools/convert_diffusers_to_flux.py index 9d8f7c74..65ba7321 100644 --- a/tools/convert_diffusers_to_flux.py +++ b/tools/convert_diffusers_to_flux.py @@ -29,6 +29,7 @@ from safetensors.torch import safe_open import torch from tqdm import tqdm +from library import flux_utils from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file setup_logging() @@ -36,65 +37,6 @@ import logging logger = logging.getLogger(__name__) -NUM_DOUBLE_BLOCKS = 19 -NUM_SINGLE_BLOCKS = 38 - -BFL_TO_DIFFUSERS_MAP = { - "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], - "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], - "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], - "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], - "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], - "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], - "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], - "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], - "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], - "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], - "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], - "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], - "txt_in.weight": ["context_embedder.weight"], - "txt_in.bias": ["context_embedder.bias"], - "img_in.weight": ["x_embedder.weight"], - "img_in.bias": ["x_embedder.bias"], - "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], - "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], - "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], - "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], - "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], - "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], - "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], - "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], - "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], - "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], - "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], - "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], - "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], - "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], - "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], - "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], - "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], - "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], - "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], - "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], - "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], - "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], - "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], - "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], - "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], - "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], - "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], - "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], - "single_blocks.().linear2.weight": ["proj_out.weight"], - "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], - "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], - "single_blocks.().linear2.weight": ["proj_out.weight"], - "single_blocks.().linear2.bias": ["proj_out.bias"], - "final_layer.linear.weight": ["proj_out.weight"], - "final_layer.linear.bias": ["proj_out.bias"], - "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], - "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], -} - def convert(args): # if diffusers_path is folder, get safetensors file @@ -114,23 +56,7 @@ def convert(args): save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None # make reverse map from diffusers map - diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) - for b in range(NUM_DOUBLE_BLOCKS): - for key, weights in BFL_TO_DIFFUSERS_MAP.items(): - if key.startswith("double_blocks."): - block_prefix = f"transformer_blocks.{b}." - for i, weight in enumerate(weights): - diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) - for b in range(NUM_SINGLE_BLOCKS): - for key, weights in BFL_TO_DIFFUSERS_MAP.items(): - if key.startswith("single_blocks."): - block_prefix = f"single_transformer_blocks.{b}." - for i, weight in enumerate(weights): - diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) - for key, weights in BFL_TO_DIFFUSERS_MAP.items(): - if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")): - for i, weight in enumerate(weights): - diffusers_to_bfl_map[weight] = (i, key) + diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map() # iterate over three safetensors files to reduce memory usage flux_sd = {} From 886f75345c95cddec8752ffdd4e60a471ee75403 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 10 Oct 2024 08:27:15 +0900 Subject: [PATCH 233/356] support weighted captions for sdxl LoRA and fine tuning --- library/strategy_base.py | 5 ++++- library/strategy_sdxl.py | 3 ++- sdxl_train.py | 38 ++++++++++++++++++++------------------ sdxl_train_control_net.py | 7 ++----- train_network.py | 27 +++++++++++++++++---------- 5 files changed, 45 insertions(+), 35 deletions(-) diff --git a/library/strategy_base.py b/library/strategy_base.py index 10820afa..7981bd0b 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -74,6 +74,9 @@ class TokenizeStrategy: raise NotImplementedError def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + returns: [tokens1, tokens2, ...], [weights1, weights2, ...] + """ raise NotImplementedError def _get_weighted_input_ids( @@ -303,7 +306,7 @@ class TextEncodingStrategy: :return: list of output embeddings for each architecture """ raise NotImplementedError - + def encode_tokens_with_weights( self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor] ) -> List[torch.Tensor]: diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py index b48e6d55..6650e2b4 100644 --- a/library/strategy_sdxl.py +++ b/library/strategy_sdxl.py @@ -174,7 +174,8 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy): """ Args: tokenize_strategy: TokenizeStrategy - models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)] + models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)]. + If text_encoder2 is wrapped by accelerate, unwrapped_text_encoder2 is required tokens: List of tokens, for text_encoder1 and text_encoder2 """ if len(models) == 2: diff --git a/sdxl_train.py b/sdxl_train.py index 7291ddd2..320169d7 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -104,8 +104,8 @@ def train(args): setup_logging(args, reset=True) assert ( - not args.weighted_captions - ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + not args.weighted_captions or not args.cache_text_encoder_outputs + ), "weighted_captions is not supported when caching text encoder outputs / cache_text_encoder_outputsを使うときはweighted_captionsはサポートされていません" assert ( not args.train_text_encoder or not args.cache_text_encoder_outputs ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" @@ -660,22 +660,24 @@ def train(args): input_ids1, input_ids2 = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning - # TODO support weighted captions - # if args.weighted_captions: - # encoder_hidden_states = get_weighted_text_embeddings( - # tokenizer, - # text_encoder, - # batch["captions"], - # accelerator.device, - # args.max_token_length // 75 if args.max_token_length else 1, - # clip_skip=args.clip_skip, - # ) - # else: - input_ids1 = input_ids1.to(accelerator.device) - input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( - tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] - ) + if args.weighted_captions: + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoder_hidden_states1, encoder_hidden_states2, pool2 = ( + text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + [text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)], + input_ids_list, + weights_list, + ) + ) + else: + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, + [text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)], + [input_ids1, input_ids2], + ) if args.full_fp16: encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index b902cda6..f6cc5a4f 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -12,24 +12,21 @@ from library.device_utils import init_ipex, clean_memory_on_device init_ipex() -from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from accelerate import init_empty_weights -from diffusers import DDPMScheduler, ControlNetModel +from diffusers import DDPMScheduler from diffusers.utils.torch_utils import is_compiled_module from safetensors.torch import load_file from library import ( deepspeed_utils, sai_model_spec, sdxl_model_util, - sdxl_original_unet, sdxl_train_util, strategy_base, strategy_sd, strategy_sdxl, ) -import library.model_util as model_util import library.train_util as train_util import library.config_util as config_util from library.config_util import ( @@ -264,7 +261,7 @@ def train(args): trainable_params.append({"params": ctrlnet_params, "lr": args.control_net_lr}) trainable_params.append({"params": unet_params, "lr": args.learning_rate}) all_params = ctrlnet_params + unet_params - + logger.info(f"trainable params count: {len(all_params)}") logger.info(f"number of trainable parameters: {sum(p.numel() for p in all_params)}") diff --git a/train_network.py b/train_network.py index f0d397b9..e48e6a07 100644 --- a/train_network.py +++ b/train_network.py @@ -1123,14 +1123,21 @@ class NetworkTrainer: with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: - # SD only - encoded_text_encoder_conds = get_weighted_text_embeddings( - tokenizers[0], - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, + # # SD only + # encoded_text_encoder_conds = get_weighted_text_embeddings( + # tokenizers[0], + # text_encoder, + # batch["captions"], + # accelerator.device, + # args.max_token_length // 75 if args.max_token_length else 1, + # clip_skip=args.clip_skip, + # ) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids_list, + weights_list, ) else: input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] @@ -1139,8 +1146,8 @@ class NetworkTrainer: self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids, ) - if args.full_fp16: - encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] + if args.full_fp16: + encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] # if text_encoder_conds is not cached, use encoded_text_encoder_conds if len(text_encoder_conds) == 0: From 3de42b6edb151b172f483aec99fe380b1406a84a Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 10 Oct 2024 14:03:59 +0800 Subject: [PATCH 234/356] fix: distributed training in windows --- library/train_util.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index e023f63a..3dabf9e2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5045,17 +5045,18 @@ def prepare_accelerator(args: argparse.Namespace): if args.torch_compile: dynamo_backend = args.dynamo_backend - kwargs_handlers = ( - InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None, - ( - DistributedDataParallelKwargs( - gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph - ) - if args.ddp_gradient_as_bucket_view or args.ddp_static_graph - else None - ), - ) - kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers)) + kwargs_handlers = [ + InitProcessGroupKwargs( + backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method="env://?use_libuv=False" if os.name == "nt" else None, + timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None + ) if torch.cuda.device_count() > 1 else None, + DistributedDataParallelKwargs( + gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, + static_graph=args.ddp_static_graph + ) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None + ] + kwargs_handlers = [i for i in kwargs_handlers if i is not None] deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) accelerator = Accelerator( From 9f4dac5731fe2299c75b7671c6132febd57a4117 Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 10 Oct 2024 14:08:55 +0800 Subject: [PATCH 235/356] torch 2.4 --- library/train_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 3dabf9e2..2c20a924 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -33,6 +33,7 @@ from io import BytesIO import toml from tqdm import tqdm +from packaging.version import Version import torch from library.device_utils import init_ipex, clean_memory_on_device @@ -5048,7 +5049,7 @@ def prepare_accelerator(args: argparse.Namespace): kwargs_handlers = [ InitProcessGroupKwargs( backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", - init_method="env://?use_libuv=False" if os.name == "nt" else None, + init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None, timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None ) if torch.cuda.device_count() > 1 else None, DistributedDataParallelKwargs( From f2bc8201330d1370c182c57047a5c23e9c6bee71 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 11 Oct 2024 08:48:55 +0900 Subject: [PATCH 236/356] support weighted captions for SD/SDXL --- fine_tune.py | 17 ++++-------- library/sdxl_train_util.py | 6 ++-- library/strategy_base.py | 12 +++++++- library/strategy_sd.py | 36 ++++++++++++++++++++++++ library/strategy_sdxl.py | 57 ++++++++++++++++++++++++++------------ sdxl_train.py | 2 +- sdxl_train_network.py | 4 ++- train_db.py | 16 ++++------- 8 files changed, 105 insertions(+), 45 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 62a545a1..fd63385b 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -366,22 +366,17 @@ def train(args): with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning if args.weighted_captions: - # TODO move to strategy_sd.py - encoder_hidden_states = get_weighted_text_embeddings( - tokenize_strategy.tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, [text_encoder], input_ids_list, weights_list + )[0] else: input_ids = batch["input_ids_list"][0].to(accelerator.device) encoder_hidden_states = text_encoding_strategy.encode_tokens( tokenize_strategy, [text_encoder], [input_ids] )[0] - if args.full_fp16: - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index aaf77b8d..dc3887c3 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -363,9 +363,9 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin # ) # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") - assert ( - not hasattr(args, "weighted_captions") or not args.weighted_captions - ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" + # assert ( + # not hasattr(args, "weighted_captions") or not args.weighted_captions + # ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" if supportTextEncoderCaching: if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: diff --git a/library/strategy_base.py b/library/strategy_base.py index 7981bd0b..2bff4178 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -323,12 +323,18 @@ class TextEncoderOutputsCachingStrategy: _strategy = None # strategy instance: actual strategy class def __init__( - self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + is_weighted: bool = False, ) -> None: self._cache_to_disk = cache_to_disk self._batch_size = batch_size self.skip_disk_cache_validity_check = skip_disk_cache_validity_check self._is_partial = is_partial + self._is_weighted = is_weighted @classmethod def set_strategy(cls, strategy): @@ -352,6 +358,10 @@ class TextEncoderOutputsCachingStrategy: def is_partial(self): return self._is_partial + @property + def is_weighted(self): + return self._is_weighted + def get_outputs_npz_path(self, image_abs_path: str) -> str: raise NotImplementedError diff --git a/library/strategy_sd.py b/library/strategy_sd.py index 83ffaa31..4e7931fd 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -40,6 +40,16 @@ class SdTokenizeStrategy(TokenizeStrategy): text = [text] if isinstance(text, str) else text return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: + text = [text] if isinstance(text, str) else text + tokens_list = [] + weights_list = [] + for t in text: + tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True) + tokens_list.append(tokens) + weights_list.append(weights) + return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)] + class SdTextEncodingStrategy(TextEncodingStrategy): def __init__(self, clip_skip: Optional[int] = None) -> None: @@ -58,6 +68,8 @@ class SdTextEncodingStrategy(TextEncodingStrategy): model_max_length = sd_tokenize_strategy.tokenizer.model_max_length tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77 + tokens = tokens.to(text_encoder.device) + if self.clip_skip is None: encoder_hidden_states = text_encoder(tokens)[0] else: @@ -93,6 +105,30 @@ class SdTextEncodingStrategy(TextEncodingStrategy): return [encoder_hidden_states] + def encode_tokens_with_weights( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens_list: List[torch.Tensor], + weights_list: List[torch.Tensor], + ) -> List[torch.Tensor]: + encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0] + + weights = weights_list[0].to(encoder_hidden_states.device) + + # apply weights + if weights.shape[1] == 1: # no max_token_length + # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) + encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2) + else: + # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) + for i in range(weights.shape[1]): + encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[ + :, i, 1:-1 + ].unsqueeze(-1) + + return [encoder_hidden_states] + class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix. diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py index 6650e2b4..6b3e2afa 100644 --- a/library/strategy_sdxl.py +++ b/library/strategy_sdxl.py @@ -42,16 +42,16 @@ class SdxlTokenizeStrategy(TokenizeStrategy): tokens1_list, tokens2_list = [], [] weights1_list, weights2_list = [], [] for t in text: - tokens1, weights1 = self._get_weighted_input_ids(self.tokenizer1, t, self.max_length) - tokens2, weights2 = self._get_weighted_input_ids(self.tokenizer2, t, self.max_length) + tokens1, weights1 = self._get_input_ids(self.tokenizer1, t, self.max_length, weighted=True) + tokens2, weights2 = self._get_input_ids(self.tokenizer2, t, self.max_length, weighted=True) tokens1_list.append(tokens1) tokens2_list.append(tokens2) weights1_list.append(weights1) weights2_list.append(weights2) - return (torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)), ( + return [torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)], [ torch.stack(weights1_list, dim=0), torch.stack(weights2_list, dim=0), - ) + ] class SdxlTextEncodingStrategy(TextEncodingStrategy): @@ -193,20 +193,28 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy): return [hidden_states1, hidden_states2, pool2] def encode_tokens_with_weights( - self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor] + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens_list: List[torch.Tensor], + weights_list: List[torch.Tensor], ) -> List[torch.Tensor]: - hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens) + hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens_list) + + weights_list = [weights.to(hidden_states1.device) for weights in weights_list] # apply weights - if weights[0].shape[1] == 1: # no max_token_length + if weights_list[0].shape[1] == 1: # no max_token_length # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) - hidden_states1 = hidden_states1 * weights[0].squeeze(1).unsqueeze(2) - hidden_states2 = hidden_states2 * weights[1].squeeze(1).unsqueeze(2) + hidden_states1 = hidden_states1 * weights_list[0].squeeze(1).unsqueeze(2) + hidden_states2 = hidden_states2 * weights_list[1].squeeze(1).unsqueeze(2) else: # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) - for weight, hidden_states in zip(weights, [hidden_states1, hidden_states2]): + for weight, hidden_states in zip(weights_list, [hidden_states1, hidden_states2]): for i in range(weight.shape[1]): - hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[:, i, 1:-1] + hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[ + :, i, 1:-1 + ].unsqueeze(-1) return [hidden_states1, hidden_states2, pool2] @@ -215,9 +223,14 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz" def __init__( - self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + is_weighted: bool = False, ) -> None: - super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted) def get_outputs_npz_path(self, image_abs_path: str) -> str: return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX @@ -253,11 +266,19 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy captions = [info.caption for info in infos] - tokens1, tokens2 = tokenize_strategy.tokenize(captions) - with torch.no_grad(): - hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, [tokens1, tokens2] - ) + if self.is_weighted: + tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions) + with torch.no_grad(): + hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, models, tokens_list, weights_list + ) + else: + tokens1, tokens2 = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, [tokens1, tokens2] + ) + if hidden_state1.dtype == torch.bfloat16: hidden_state1 = hidden_state1.float() if hidden_state2.dtype == torch.bfloat16: diff --git a/sdxl_train.py b/sdxl_train.py index 320169d7..aeff9c46 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -321,7 +321,7 @@ def train(args): if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, None, False + args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions ) strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 4d6e3f18..20e32155 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -79,7 +79,9 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: - return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False) + return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions + ) else: return None diff --git a/train_db.py b/train_db.py index a5d520b1..e49a7e70 100644 --- a/train_db.py +++ b/train_db.py @@ -356,21 +356,17 @@ def train(args): # Get the text embedding for conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenize_strategy.tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, [text_encoder], input_ids_list, weights_list + )[0] else: input_ids = batch["input_ids_list"][0].to(accelerator.device) encoder_hidden_states = text_encoding_strategy.encode_tokens( tokenize_strategy, [text_encoder], [input_ids] )[0] - if args.full_fp16: - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified From 035c4a8552bf6214ad4d39657d3eb1204cdecdfd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 11 Oct 2024 22:23:15 +0900 Subject: [PATCH 237/356] update docs and help text --- README.md | 10 ++++++++++ docs/train_lllite_README.md | 2 +- sdxl_train_control_net.py | 2 +- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c567758a..d3f49c99 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,16 @@ The command to install PyTorch is as follows: ### Recent Updates +Oct 11, 2024: +- ControlNet training for SDXL has been implemented in this branch. Please use `sdxl_train_control_net.py`. + - For details on defining the dataset, see [here](docs/train_lllite_README.md#creating-a-dataset-configuration-file). + - The learning rate for the copy part of the U-Net is specified by `--learning_rate`. The learning rate for the added modules in ControlNet is specified by `--control_net_lr`. The optimal value is still unknown, but try around U-Net `1e-5` and ControlNet `1e-4`. + - If you want to generate sample images, specify the control image as `--cn path/to/control/image`. + - The trained weights are automatically converted and saved in Diffusers format. It should be available in ComfyUI. +- Weighting of prompts (captions) during training in SDXL is now supported (e.g., `(some text)`, `[some text]`, `(some text:1.4)`, etc.). The function is enabled by specifying `--weighted_captions`. + - The default is `False`. It is same as before, and the parentheses are used as normal text. + - If `--weighted_captions` is specified, please use `\` to escape the parentheses in the prompt. For example, `\(some text:1.4\)`. + Oct 6, 2024: - In FLUX.1 LoRA training and fine-tuning, the specified weight file (*.safetensors) is automatically determined to be dev or schnell. This allows schnell models to be loaded correctly. Note that LoRA training with schnell models and fine-tuning with schnell models are unverified. - FLUX.1 LoRA training and fine-tuning can now load weights in Diffusers format in addition to BFL format (a single *.safetensors file). Please specify the parent directory of `transformer` or `diffusion_pytorch_model-00001-of-00003.safetensors` with the full path. However, Diffusers format CLIP/T5XXL is not supported. Saving is supported only in BFL format. diff --git a/docs/train_lllite_README.md b/docs/train_lllite_README.md index a05f87f5..1bd8e4ae 100644 --- a/docs/train_lllite_README.md +++ b/docs/train_lllite_README.md @@ -185,7 +185,7 @@ for img_file in img_files: ### Creating a dataset configuration file -You can use the command line arguments of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`. +You can use the command line argument `--conditioning_data_dir` of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`. ```toml [general] diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index f6cc5a4f..67c8d52c 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -705,7 +705,7 @@ def setup_parser() -> argparse.ArgumentParser: "--control_net_lr", type=float, default=1e-4, - help="learning rate for controlnet / controlnetの学習率", + help="learning rate for controlnet modules / controlnetモジュールの学習率", ) return parser From 0d3058b65ab7cd827e44f16f84c68a4bb73f701e Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sat, 12 Oct 2024 14:46:35 +0900 Subject: [PATCH 238/356] update README --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md index d3f49c99..37fc911f 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,17 @@ The command to install PyTorch is as follows: ### Recent Updates +Oct 12, 2024: + +- Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)! + - It should work with all training scripts, but it is unverified. + - Set up multi-GPU training with `accelerate config`. + - Specify `--rdzv_backend=c10d` when launching `accelerate launch`. You can also edit `config.yaml` directly. + ``` + accelerate launch --rdzv_backend=c10d sdxl_train_network.py ... + ``` + - In multi-GPU training, the memory of multiple GPUs is not integrated. In other words, even if you have two 12GB VRAM GPUs, you cannot train the model that requires 24GB VRAM. Training that can be done with 12GB VRAM is executed at (up to) twice the speed. + Oct 11, 2024: - ControlNet training for SDXL has been implemented in this branch. Please use `sdxl_train_control_net.py`. - For details on defining the dataset, see [here](docs/train_lllite_README.md#creating-a-dataset-configuration-file). From c80c304779775f4d00fd8f4856bfc8e6599e2de0 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sat, 12 Oct 2024 20:18:41 +0900 Subject: [PATCH 239/356] Refactor caching in train scripts --- README.md | 10 +++++ fine_tune.py | 2 +- flux_train.py | 14 ++++--- flux_train_network.py | 6 +-- library/train_util.py | 64 +++++++++++++++++++++++--------- sd3_train.py | 17 +++++++-- sdxl_train.py | 4 +- sdxl_train_control_net.py | 4 +- sdxl_train_control_net_lllite.py | 5 +-- sdxl_train_network.py | 8 ++-- sdxl_train_textual_inversion.py | 2 +- train_db.py | 2 +- train_network.py | 2 +- train_textual_inversion.py | 2 +- 14 files changed, 95 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index 37fc911f..2b256283 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,16 @@ The command to install PyTorch is as follows: ### Recent Updates +Oct 12, 2024 (update 1): + +- During multi-GPU training, caching of latents and Text Encoder outputs is now done in multi-GPU. +- `--text_encoder_batch_size` option is enabled for FLUX.1 LoRA training and fine tuning. This option specifies the batch size for caching Text Encoder outputs (not for training). The default is same as the dataset batch size. If you have enough VRAM, you can increase the batch size to speed up the caching. +- `--skip_cache_check` option is added to each training script. + - When specified, the consistency check of the cache file `*.npz` contents (e.g., image size and flip for latents, mask for Text Encoder outputs) is skipped. + - Specify this option if you have a large number of cache files and the consistency check takes time. + - Even if this option is specified, the cache will be created if the file does not exist. + - `--skip_latents_validity_check` in SD3/FLUX.1 is deprecated. Please use `--skip_cache_check` instead. + Oct 12, 2024: - Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)! diff --git a/fine_tune.py b/fine_tune.py index fd63385b..cdc005d9 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -59,7 +59,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. if cache_latents: latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) diff --git a/flux_train.py b/flux_train.py index ecc87c0a..e18a9244 100644 --- a/flux_train.py +++ b/flux_train.py @@ -57,6 +57,10 @@ def train(args): deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + # assert ( # not args.weighted_captions # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" @@ -81,7 +85,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. if args.cache_latents: latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( - args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -142,7 +146,7 @@ def train(args): if args.cache_text_encoder_outputs: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( strategy_flux.FluxTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False ) ) t5xxl_max_token_length = ( @@ -181,7 +185,7 @@ def train(args): # load VAE for caching latents ae = None if cache_latents: - ae = flux_utils.load_ae( args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) ae.to(accelerator.device, dtype=weight_dtype) ae.requires_grad_(False) ae.eval() @@ -229,7 +233,7 @@ def train(args): strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator) # cache sample prompt's embeddings to free text encoder's memory if args.sample_prompts is not None: @@ -952,7 +956,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--skip_latents_validity_check", action="store_true", - help="skip latents validity check / latentsの正当性チェックをスキップする", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", ) parser.add_argument( "--blocks_to_swap", diff --git a/flux_train_network.py b/flux_train_network.py index 5d14bd28..3bd8316d 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -188,8 +188,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # if the text encoders is trained, we need tokenization, so is_partial is True return strategy_flux.FluxTextEncoderOutputsCachingStrategy( args.cache_text_encoder_outputs_to_disk, - None, - False, + args.text_encoder_batch_size, + args.skip_cache_check, is_partial=self.train_clip_l or self.train_t5xxl, apply_t5_attn_mask=args.apply_t5_attn_mask, ) @@ -222,7 +222,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): text_encoders[1].to(weight_dtype) with accelerator.autocast(): - dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) # cache sample prompts if args.sample_prompts is not None: diff --git a/library/train_util.py b/library/train_util.py index 67eaae41..4e6b3408 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -31,6 +31,7 @@ import hashlib import subprocess from io import BytesIO import toml + # from concurrent.futures import ThreadPoolExecutor, as_completed from tqdm import tqdm @@ -1192,7 +1193,7 @@ class BaseDataset(torch.utils.data.Dataset): for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) - def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator): r""" a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy. """ @@ -1207,15 +1208,25 @@ class BaseDataset(torch.utils.data.Dataset): # split by resolution batches = [] batch = [] - logger.info("checking cache validity...") - for info in tqdm(image_infos): - te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) - # check disk cache exists and size of latents + # support multiple-gpus + num_processes = accelerator.num_processes + process_index = accelerator.process_index + + logger.info("checking cache validity...") + for i, info in enumerate(tqdm(image_infos)): + # check disk cache exists and size of text encoder outputs if caching_strategy.cache_to_disk: - info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability/main process + te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) + info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability + + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different text encoder outputs + if i % num_processes != process_index: + continue + cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) - if cache_available or not is_main_process: # do not add to batch + if cache_available: # do not add to batch continue batch.append(info) @@ -2420,6 +2431,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") dataset.new_cache_latents(model, accelerator) + accelerator.wait_for_everyone() def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True @@ -2437,10 +2449,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset): tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size ) - def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.new_cache_text_encoder_outputs(models, is_main_process) + dataset.new_cache_text_encoder_outputs(models, accelerator) + accelerator.wait_for_everyone() def set_caching_mode(self, caching_mode): for dataset in self.datasets: @@ -4210,6 +4223,12 @@ def add_dataset_arguments( action="store_true", help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)", ) + parser.add_argument( + "--skip_cache_check", + action="store_true", + help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist" + " / cacheの内容の検証をスキップする(latentとテキストエンコーダの出力)。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる", + ) parser.add_argument( "--enable_bucket", action="store_true", @@ -5084,15 +5103,24 @@ def prepare_accelerator(args: argparse.Namespace): dynamo_backend = args.dynamo_backend kwargs_handlers = [ - InitProcessGroupKwargs( - backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", - init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None, - timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None - ) if torch.cuda.device_count() > 1 else None, - DistributedDataParallelKwargs( - gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, - static_graph=args.ddp_static_graph - ) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None + ( + InitProcessGroupKwargs( + backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method=( + "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None + ), + timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None, + ) + if torch.cuda.device_count() > 1 + else None + ), + ( + DistributedDataParallelKwargs( + gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph + ) + if args.ddp_gradient_as_bucket_view or args.ddp_static_graph + else None + ), ] kwargs_handlers = [i for i in kwargs_handlers if i is not None] deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) diff --git a/sd3_train.py b/sd3_train.py index 5120105f..7290956a 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -57,6 +57,10 @@ def train(args): deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + assert ( not args.weighted_captions ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" @@ -103,7 +107,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. if args.cache_latents: latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( - args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -312,7 +316,7 @@ def train(args): text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, - False, + args.skip_cache_check, train_clip_g or train_clip_l or args.use_t5xxl_cache_only, args.apply_lg_attn_mask, args.apply_t5_attn_mask, @@ -325,7 +329,7 @@ def train(args): t5xxl.to(t5xxl_device, dtype=t5xxl_dtype) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator) # cache sample prompt's embeddings to free text encoder's memory if args.sample_prompts is not None: @@ -1052,7 +1056,12 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--skip_latents_validity_check", action="store_true", - help="skip latents validity check / latentsの正当性チェックをスキップする", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", + ) + parser.add_argument( + "--skip_cache_check", + action="store_true", + help="skip cache (latents and text encoder outputs) check / キャッシュ(latentsとtext encoder outputs)のチェックをスキップする", ) parser.add_argument( "--num_last_block_to_freeze", diff --git a/sdxl_train.py b/sdxl_train.py index aeff9c46..9b2d1916 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -131,7 +131,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. if args.cache_latents: latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -328,7 +328,7 @@ def train(args): text_encoder1.to(accelerator.device) text_encoder2.to(accelerator.device) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) accelerator.wait_for_everyone() diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 67c8d52c..74b3a64a 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -84,7 +84,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -230,7 +230,7 @@ def train(args): text_encoder1.to(accelerator.device) text_encoder2.to(accelerator.device) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) accelerator.wait_for_everyone() diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 9d1cfc63..14ff7c24 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -93,7 +93,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) @@ -202,7 +202,7 @@ def train(args): text_encoder1.to(accelerator.device) text_encoder2.to(accelerator.device) with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) accelerator.wait_for_everyone() @@ -431,7 +431,6 @@ def train(args): latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: # Text Encoder outputs are cached diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 20e32155..4a16a489 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -67,7 +67,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy @@ -80,7 +80,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions + args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions ) else: return None @@ -102,9 +102,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): text_encoders[0].to(accelerator.device, dtype=weight_dtype) text_encoders[1].to(accelerator.device, dtype=weight_dtype) with accelerator.autocast(): - dataset.new_cache_text_encoder_outputs( - text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator.is_main_process - ) + dataset.new_cache_text_encoder_outputs(text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator) accelerator.wait_for_everyone() text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index cbfcef55..821a6955 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -49,7 +49,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy diff --git a/train_db.py b/train_db.py index e49a7e70..683b4233 100644 --- a/train_db.py +++ b/train_db.py @@ -64,7 +64,7 @@ def train(args): # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - False, args.cache_latents_to_disk, args.vae_batch_size, False + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) diff --git a/train_network.py b/train_network.py index 7437157b..d5330aef 100644 --- a/train_network.py +++ b/train_network.py @@ -116,7 +116,7 @@ class NetworkTrainer: def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - True, args.cache_latents_to_disk, args.vae_batch_size, False + True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 3b3d3393..4d8a3abb 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -114,7 +114,7 @@ class TextualInversionTrainer: def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( - True, args.cache_latents_to_disk, args.vae_batch_size, False + True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy From ecaea909b10fa8b3eb94a1cf57b26d5daba1683e Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sat, 12 Oct 2024 20:26:57 +0900 Subject: [PATCH 240/356] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 37fc911f..9128bf8d 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ The command to install PyTorch is as follows: Oct 12, 2024: - Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)! - - It should work with all training scripts, but it is unverified. + - In simple tests, SDXL and FLUX.1 LoRA training worked. FLUX.1 fine-tuning did not work, probably due to a PyTorch-related error. Other scripts are unverified. - Set up multi-GPU training with `accelerate config`. - Specify `--rdzv_backend=c10d` when launching `accelerate launch`. You can also edit `config.yaml` directly. ``` From e277b5789e791539b5e51187530f11bd94e24871 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 12 Oct 2024 21:49:07 +0900 Subject: [PATCH 241/356] Update FLUX.1 support for compact models --- README.md | 10 ++++++ flux_train.py | 12 +++---- flux_train_network.py | 2 +- library/flux_utils.py | 76 ++++++++++++++++++++++++++++++++++++------- 4 files changed, 82 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 9128bf8d..b64515a1 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,16 @@ The command to install PyTorch is as follows: ### Recent Updates +Oct 12, 2024 (update 1): + +- [Experimental] FLUX.1 fine-tuning and LoRA training now support "FLUX.1 __compact__" models. + - A compact model is a model that retains the FLUX.1 architecture but reduces the number of double/single blocks from the default 19/38. + - The model is automatically determined based on the keys in *.safetensors. + - Specifications for compact model safetensors: + - Please specify the block indices as consecutive numbers. An error will occur if there are missing numbers. For example, if you reduce the double blocks to 15, the maximum key will be `double_blocks.14.*`. The same applies to single blocks. + - LoRA training is unverified. + - The trained model can be used for inference with `flux_minimal_inference.py`. Other inference environments are unverified. + Oct 12, 2024: - Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)! diff --git a/flux_train.py b/flux_train.py index ecc87c0a..2fc13068 100644 --- a/flux_train.py +++ b/flux_train.py @@ -137,7 +137,7 @@ def train(args): train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 - _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) if args.debug_dataset: if args.cache_text_encoder_outputs: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( @@ -181,7 +181,7 @@ def train(args): # load VAE for caching latents ae = None if cache_latents: - ae = flux_utils.load_ae( args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) ae.to(accelerator.device, dtype=weight_dtype) ae.requires_grad_(False) ae.eval() @@ -510,8 +510,8 @@ def train(args): library.adafactor_fused.patch_adafactor_fused(optimizer) blocks_to_swap = args.blocks_to_swap - num_double_blocks = 19 # len(flux.double_blocks) - num_single_blocks = 38 # len(flux.single_blocks) + num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) + num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) num_block_units = num_double_blocks + num_single_blocks // 2 handled_unit_indices = set() @@ -603,8 +603,8 @@ def train(args): parameter_optimizer_map = {} blocks_to_swap = args.blocks_to_swap - num_double_blocks = 19 # len(flux.double_blocks) - num_single_blocks = 38 # len(flux.single_blocks) + num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) + num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) num_block_units = num_double_blocks + num_single_blocks // 2 n = 1 # only asynchronous purpose, no need to increase this number diff --git a/flux_train_network.py b/flux_train_network.py index 5d14bd28..a24c1905 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -139,7 +139,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return flux_lower def get_tokenize_strategy(self, args): - _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) if args.t5xxl_max_token_length is None: if is_schnell: diff --git a/library/flux_utils.py b/library/flux_utils.py index 713814e2..7a1ec37b 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -1,3 +1,4 @@ +from dataclasses import replace import json import os from typing import List, Optional, Tuple, Union @@ -43,8 +44,21 @@ def load_safetensors( return load_file(path) # prevent device invalid Error -def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool, List[str]]: - # check the state dict: Diffusers or BFL, dev or schnell +def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: + """ + チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。 + + Args: + ckpt_path (str): チェックポイントファイルまたはディレクトリのパス。 + + Returns: + Tuple[bool, bool, Tuple[int, int], List[str]]: + - bool: Diffusersかどうかを示すフラグ。 + - bool: Schnellかどうかを示すフラグ。 + - Tuple[int, int]: ダブルブロックとシングルブロックの数。 + - List[str]: チェックポイントに含まれるキーのリスト。 + """ + # check the state dict: Diffusers or BFL, dev or schnell, number of blocks logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell") if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers @@ -61,19 +75,57 @@ def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool, is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) - return is_diffusers, is_schnell, ckpt_paths + + # check number of double and single blocks + if not is_diffusers: + max_double_block_index = max( + [int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")] + ) + max_single_block_index = max( + [int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")] + ) + else: + max_double_block_index = max( + [ + int(key.split(".")[1]) + for key in keys + if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias") + ] + ) + max_single_block_index = max( + [ + int(key.split(".")[1]) + for key in keys + if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias") + ] + ) + + num_double_blocks = max_double_block_index + 1 + num_single_blocks = max_single_block_index + 1 + + return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths def load_flow_model( ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False ) -> Tuple[bool, flux_models.Flux]: - is_diffusers, is_schnell, ckpt_paths = check_flux_state_dict_diffusers_schnell(ckpt_path) + is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL # build model logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint") with torch.device("meta"): - model = flux_models.Flux(flux_models.configs[name].params) + params = flux_models.configs[name].params + + # set the number of blocks + if params.depth != num_double_blocks: + logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}") + params = replace(params, depth=num_double_blocks) + if params.depth_single_blocks != num_single_blocks: + logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}") + params = replace(params, depth_single_blocks=num_single_blocks) + + model = flux_models.Flux(params) if dtype is not None: model = model.to(dtype) @@ -86,7 +138,7 @@ def load_flow_model( # convert Diffusers to BFL if is_diffusers: logger.info("Converting Diffusers to BFL") - sd = convert_diffusers_sd_to_bfl(sd) + sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks) logger.info("Converted Diffusers to BFL") info = model.load_state_dict(sd, strict=False, assign=True) @@ -349,16 +401,16 @@ BFL_TO_DIFFUSERS_MAP = { } -def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]: +def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]: # make reverse map from diffusers map diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) - for b in range(NUM_DOUBLE_BLOCKS): + for b in range(num_double_blocks): for key, weights in BFL_TO_DIFFUSERS_MAP.items(): if key.startswith("double_blocks."): block_prefix = f"transformer_blocks.{b}." for i, weight in enumerate(weights): diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) - for b in range(NUM_SINGLE_BLOCKS): + for b in range(num_single_blocks): for key, weights in BFL_TO_DIFFUSERS_MAP.items(): if key.startswith("single_blocks."): block_prefix = f"single_transformer_blocks.{b}." @@ -371,8 +423,10 @@ def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]: return diffusers_to_bfl_map -def convert_diffusers_sd_to_bfl(diffusers_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - diffusers_to_bfl_map = make_diffusers_to_bfl_map() +def convert_diffusers_sd_to_bfl( + diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS +) -> dict[str, torch.Tensor]: + diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks) # iterate over three safetensors files to reduce memory usage flux_sd = {} From 74228c9953b4ba0f8b0d68e8f6c8a8a6a469c2f5 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 13 Oct 2024 16:27:22 +0900 Subject: [PATCH 242/356] update cache_latents/text_encoder_outputs --- library/strategy_base.py | 2 +- tools/cache_latents.py | 147 +++++++++++------------ tools/cache_text_encoder_outputs.py | 176 ++++++++++++++++------------ 3 files changed, 165 insertions(+), 160 deletions(-) diff --git a/library/strategy_base.py b/library/strategy_base.py index 2bff4178..363996ce 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -325,7 +325,7 @@ class TextEncoderOutputsCachingStrategy: def __init__( self, cache_to_disk: bool, - batch_size: int, + batch_size: Optional[int], skip_disk_cache_validity_check: bool, is_partial: bool = False, is_weighted: bool = False, diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 2f0098b4..d8154ec3 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -9,7 +9,7 @@ from accelerate.utils import set_seed import torch from tqdm import tqdm -from library import config_util +from library import config_util, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl from library import train_util from library import sdxl_train_util from library.config_util import ( @@ -17,42 +17,73 @@ from library.config_util import ( BlueprintGenerator, ) from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging logger = logging.getLogger(__name__) +def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argparse.Namespace) -> None: + if is_flux: + _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) + else: + is_schnell = False + + if is_sd or is_sdxl: + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + elif is_sdxl: + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + else: + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") + tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + def cache_to_disk(args: argparse.Namespace) -> None: setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) - # check cache latents arg - assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" + # assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" + args.cache_latents = True + args.cache_latents_to_disk = True use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - # tokenizerを準備する:datasetを動かすために必要 - if args.sdxl: - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - tokenizers = [tokenizer1, tokenizer2] + is_sd = not args.sdxl and not args.flux + is_sdxl = args.sdxl + is_flux = args.flux + + set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) + + if is_sd or is_sdxl: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(is_sd, True, args.vae_batch_size, args.skip_cache_check) else: - tokenizer = train_util.load_tokenizer(args) - tokenizers = [tokenizer] + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(True, args.vae_batch_size, args.skip_cache_check) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する + use_user_config = args.dataset_config is not None if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) - if args.dataset_config is not None: - logger.info(f"Load dataset config from {args.dataset_config}") + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if use_user_config: + logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] + ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) @@ -83,17 +114,11 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) - - # datasetのcache_latentsを呼ばなければ、生の画像が返る - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + # use arbitrary dataset class + train_dataset_group = train_util.load_arbitrary_dataset(args) # acceleratorを準備する logger.info("prepare accelerator") @@ -106,72 +131,27 @@ def cache_to_disk(args: argparse.Namespace) -> None: # モデルを読み込む logger.info("load model") - if args.sdxl: + if is_sd: + _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) + elif is_sdxl: (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) else: - _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) + vae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + + if is_sd or is_sdxl: + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える - vae.set_use_memory_efficient_attention_xformers(args.xformers) vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - # dataloaderを準備する - train_dataset_group.set_caching_mode("latents") - - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず - train_dataloader = accelerator.prepare(train_dataloader) - - # データ取得のためのループ - for batch in tqdm(train_dataloader): - b_size = len(batch["images"]) - vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size - flip_aug = batch["flip_aug"] - alpha_mask = batch["alpha_mask"] - random_crop = batch["random_crop"] - bucket_reso = batch["bucket_reso"] - - # バッチを分割して処理する - for i in range(0, b_size, vae_batch_size): - images = batch["images"][i : i + vae_batch_size] - absolute_paths = batch["absolute_paths"][i : i + vae_batch_size] - resized_sizes = batch["resized_sizes"][i : i + vae_batch_size] - - image_infos = [] - for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)): - image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) - image_info.image = image - image_info.bucket_reso = bucket_reso - image_info.resized_size = resized_size - image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz" - - if args.skip_existing: - if train_util.is_disk_cached_latents_is_expected( - image_info.bucket_reso, image_info.latents_npz, flip_aug, alpha_mask - ): - logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") - continue - - image_infos.append(image_info) - - if len(image_infos) > 0: - train_util.cache_batch_latents(vae, True, image_infos, flip_aug, alpha_mask, random_crop) + # cache latents with dataset + # TODO use DataLoader to speed up + train_dataset_group.new_cache_latents(vae, accelerator) accelerator.wait_for_everyone() - accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + accelerator.print(f"Finished caching latents to disk.") def setup_parser() -> argparse.ArgumentParser: @@ -182,7 +162,11 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) config_util.add_config_arguments(parser) + parser.add_argument( + "--ae", type=str, default=None, help="Autoencoder model of FLUX to use / 使用するFLUXのオートエンコーダモデル" + ) parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") + parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") parser.add_argument( "--no_half_vae", action="store_true", @@ -191,7 +175,8 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--skip_existing", action="store_true", - help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check." + " / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。", ) return parser diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index a75d9da7..d294d46c 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -9,55 +9,68 @@ from accelerate.utils import set_seed import torch from tqdm import tqdm -from library import config_util +from library import ( + config_util, + flux_train_utils, + flux_utils, + sdxl_model_util, + strategy_base, + strategy_flux, + strategy_sd, + strategy_sdxl, +) from library import train_util from library import sdxl_train_util +from library import utils from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) from library.utils import setup_logging, add_logging_arguments +from tools import cache_latents + setup_logging() import logging + logger = logging.getLogger(__name__) + def cache_to_disk(args: argparse.Namespace) -> None: setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) - # check cache arg - assert ( - args.cache_text_encoder_outputs_to_disk - ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります" - - # できるだけ準備はしておくが今のところSDXLのみしか動かない - assert ( - args.sdxl - ), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です" + args.cache_text_encoder_outputs = True + args.cache_text_encoder_outputs_to_disk = True use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - # tokenizerを準備する:datasetを動かすために必要 - if args.sdxl: - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - tokenizers = [tokenizer1, tokenizer2] - else: - tokenizer = train_util.load_tokenizer(args) - tokenizers = [tokenizer] + is_sd = not args.sdxl and not args.flux + is_sdxl = args.sdxl + is_flux = args.flux + + assert ( + is_sdxl or is_flux + ), "Cache text encoder outputs to disk is only supported for SDXL and FLUX models / テキストエンコーダ出力のディスクキャッシュはSDXLまたはFLUXでのみ有効です" + assert ( + is_sdxl or args.weighted_captions is None + ), "Weighted captions are only supported for SDXL models / 重み付きキャプションはSDXLモデルでのみ有効です" + + cache_latents.set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) # データセットを準備する + use_user_config = args.dataset_config is not None if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) - if args.dataset_config is not None: - logger.info(f"Load dataset config from {args.dataset_config}") + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if use_user_config: + logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] + ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) @@ -88,15 +101,11 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + # use arbitrary dataset class + train_dataset_group = train_util.load_arbitrary_dataset(args) # acceleratorを準備する logger.info("prepare accelerator") @@ -105,66 +114,68 @@ def cache_to_disk(args: argparse.Namespace) -> None: # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, _ = train_util.prepare_dtype(args) + t5xxl_dtype = utils.str_to_dtype(args.t5xxl_dtype, weight_dtype) # モデルを読み込む logger.info("load model") - if args.sdxl: - (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) + if is_sdxl: + _, text_encoder1, text_encoder2, _, _, _, _ = sdxl_train_util.load_target_model( + args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype + ) + text_encoder1.to(accelerator.device, weight_dtype) + text_encoder2.to(accelerator.device, weight_dtype) text_encoders = [text_encoder1, text_encoder2] else: - text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) - text_encoders = [text_encoder1] + clip_l = flux_utils.load_clip_l( + args.clip_l, weight_dtype, accelerator.device, disable_mmap=args.disable_mmap_load_safetensors + ) + + t5xxl = flux_utils.load_t5xxl(args.t5xxl, None, accelerator.device, disable_mmap=args.disable_mmap_load_safetensors) + + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + if t5xxl_dtype != t5xxl_dtype: + if t5xxl.dtype == torch.float8_e4m3fn and t5xxl_dtype.itemsize() >= 2: + logger.warning( + "The loaded model is fp8, but the specified T5XXL dtype is larger than fp8. This may cause a performance drop." + " / ロードされたモデルはfp8ですが、指定されたT5XXLのdtypeがfp8より高精度です。精度低下が発生する可能性があります。" + ) + logger.info(f"Casting T5XXL model to {t5xxl_dtype}") + t5xxl.to(t5xxl_dtype) + + text_encoders = [clip_l, t5xxl] for text_encoder in text_encoders: - text_encoder.to(accelerator.device, dtype=weight_dtype) text_encoder.requires_grad_(False) text_encoder.eval() - # dataloaderを準備する - train_dataset_group.set_caching_mode("text") + # build text encoder outputs caching strategy + if is_sdxl: + text_encoder_outputs_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions + ) + else: + text_encoder_outputs_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=False, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + # build text encoding strategy + if is_sdxl: + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + else: + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず - train_dataloader = accelerator.prepare(train_dataloader) - - # データ取得のためのループ - for batch in tqdm(train_dataloader): - absolute_paths = batch["absolute_paths"] - input_ids1_list = batch["input_ids1_list"] - input_ids2_list = batch["input_ids2_list"] - - image_infos = [] - for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list): - image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) - image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX - image_info - - if args.skip_existing: - if os.path.exists(image_info.text_encoder_outputs_npz): - logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") - continue - - image_info.input_ids1 = input_ids1 - image_info.input_ids2 = input_ids2 - image_infos.append(image_info) - - if len(image_infos) > 0: - b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos]) - b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos]) - train_util.cache_batch_text_encoder_outputs( - image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype - ) + # cache text encoder outputs + train_dataset_group.new_cache_text_encoder_outputs(text_encoders, accelerator) accelerator.wait_for_everyone() accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") @@ -179,11 +190,20 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_dataset_arguments(parser, True, True, True) config_util.add_config_arguments(parser) sdxl_train_util.add_sdxl_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") + parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") + parser.add_argument( + "--t5xxl_dtype", + type=str, + default=None, + help="T5XXL model dtype, default: None (use mixed precision dtype) / T5XXLモデルのdtype, デフォルト: None (mixed precisionのdtypeを使用)", + ) parser.add_argument( "--skip_existing", action="store_true", - help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check." + " / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。", ) return parser From 2244cf5b835cc35179f29b1babb4a2d19f54bfae Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 13 Oct 2024 18:22:19 +0900 Subject: [PATCH 243/356] load images in parallel when caching latents --- library/train_util.py | 93 ++++++++++++++++++++++++------------------- 1 file changed, 53 insertions(+), 40 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 4e6b3408..1db470d6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3,6 +3,7 @@ import argparse import ast import asyncio +from concurrent.futures import Future, ThreadPoolExecutor import datetime import importlib import json @@ -1058,7 +1059,6 @@ class BaseDataset(torch.utils.data.Dataset): and self.random_crop == other.random_crop ) - batches: List[Tuple[Condition, List[ImageInfo]]] = [] batch: List[ImageInfo] = [] current_condition = None @@ -1066,57 +1066,70 @@ class BaseDataset(torch.utils.data.Dataset): num_processes = accelerator.num_processes process_index = accelerator.process_index - logger.info("checking cache validity...") - for i, info in enumerate(tqdm(image_infos)): - subset = self.image_to_subset[info.image_key] + # define a function to submit a batch to cache + def submit_batch(batch, cond): + for info in batch: + if info.image is not None and isinstance(info.image, Future): + info.image = info.image.result() # future to image + caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop) - if info.latents_npz is not None: # fine tuning dataset - continue + # define ThreadPoolExecutor to load images in parallel + max_workers = min(os.cpu_count(), len(image_infos)) + max_workers = max(1, max_workers // num_processes) # consider multi-gpu + max_workers = min(max_workers, caching_strategy.batch_size) # max_workers should be less than batch_size + executor = ThreadPoolExecutor(max_workers) - # check disk cache exists and size of latents - if caching_strategy.cache_to_disk: - # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix - info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) + try: + # iterate images + logger.info("caching latents...") + for i, info in enumerate(tqdm(image_infos)): + subset = self.image_to_subset[info.image_key] - # if the modulo of num_processes is not equal to process_index, skip caching - # this makes each process cache different latents - if i % num_processes != process_index: + if info.latents_npz is not None: # fine tuning dataset continue - # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") + # check disk cache exists and size of latents + if caching_strategy.cache_to_disk: + # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix + info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) - cache_available = caching_strategy.is_disk_cached_latents_expected( - info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask - ) - if cache_available: # do not add to batch - continue + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different latents + if i % num_processes != process_index: + continue - # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty - condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) - if len(batch) > 0 and current_condition != condition: - batches.append((current_condition, batch)) - batch = [] + # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") - batch.append(info) - current_condition = condition + cache_available = caching_strategy.is_disk_cached_latents_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) + if cache_available: # do not add to batch + continue - # if number of data in batch is enough, flush the batch - if len(batch) >= caching_strategy.batch_size: - batches.append((current_condition, batch)) - batch = [] - current_condition = None + # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + submit_batch(batch, current_condition) + batch = [] - if len(batch) > 0: - batches.append((current_condition, batch)) + if info.image is None: + # load image in parallel + info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask) - if len(batches) == 0: - logger.info("no latents to cache") - return + batch.append(info) + current_condition = condition - # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded - logger.info("caching latents...") - for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): - caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) + # if number of data in batch is enough, flush the batch + if len(batch) >= caching_strategy.batch_size: + submit_batch(batch, current_condition) + batch = [] + current_condition = None + + if len(batch) > 0: + submit_batch(batch, current_condition) + + finally: + executor.shutdown() def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと From bfc3a65acda7f90abef9c16db279d2952f73fb77 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 13 Oct 2024 19:08:16 +0900 Subject: [PATCH 244/356] fix to work cache latents/text encoder outputs --- library/train_util.py | 11 +++++++---- tools/cache_latents.py | 11 ++++++----- tools/cache_text_encoder_outputs.py | 18 +++++++++++++----- 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 1db470d6..92660926 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4064,15 +4064,18 @@ def verify_command_line_training_args(args: argparse.Namespace): ) +def enable_high_vram(args: argparse.Namespace): + if args.highvram: + logger.info("highvram is enabled / highvramが有効です") + global HIGH_VRAM + HIGH_VRAM = True + def verify_training_args(args: argparse.Namespace): r""" Verify training arguments. Also reflect highvram option to global variable 学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する """ - if args.highvram: - print("highvram is enabled / highvramが有効です") - global HIGH_VRAM - HIGH_VRAM = True + enable_high_vram(args) if args.v_parameterization and not args.v2: logger.warning( diff --git a/tools/cache_latents.py b/tools/cache_latents.py index d8154ec3..e2faa58a 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -9,7 +9,7 @@ from accelerate.utils import set_seed import torch from tqdm import tqdm -from library import config_util, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl +from library import config_util, flux_train_utils, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl from library import train_util from library import sdxl_train_util from library.config_util import ( @@ -30,7 +30,7 @@ def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argpa else: is_schnell = False - if is_sd or is_sdxl: + if is_sd: tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) elif is_sdxl: tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) @@ -51,6 +51,7 @@ def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argpa def cache_to_disk(args: argparse.Namespace) -> None: setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) + train_util.enable_high_vram(args) # assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" args.cache_latents = True @@ -161,10 +162,10 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) - parser.add_argument( - "--ae", type=str, default=None, help="Autoencoder model of FLUX to use / 使用するFLUXのオートエンコーダモデル" - ) + flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") parser.add_argument( diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index d294d46c..7be9ad78 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -27,7 +27,7 @@ from library.config_util import ( BlueprintGenerator, ) from library.utils import setup_logging, add_logging_arguments -from tools import cache_latents +from cache_latents import set_tokenize_strategy setup_logging() import logging @@ -38,6 +38,7 @@ logger = logging.getLogger(__name__) def cache_to_disk(args: argparse.Namespace) -> None: setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) + train_util.enable_high_vram(args) args.cache_text_encoder_outputs = True args.cache_text_encoder_outputs_to_disk = True @@ -57,8 +58,8 @@ def cache_to_disk(args: argparse.Namespace) -> None: assert ( is_sdxl or args.weighted_captions is None ), "Weighted captions are only supported for SDXL models / 重み付きキャプションはSDXLモデルでのみ有効です" - - cache_latents.set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) + + set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) # データセットを準備する use_user_config = args.dataset_config is not None @@ -178,7 +179,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: train_dataset_group.new_cache_text_encoder_outputs(text_encoders, accelerator) accelerator.wait_for_everyone() - accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + accelerator.print(f"Finished caching text encoder outputs to disk.") def setup_parser() -> argparse.ArgumentParser: @@ -188,9 +189,10 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_models_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) - sdxl_train_util.add_sdxl_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") parser.add_argument( @@ -205,6 +207,12 @@ def setup_parser() -> argparse.ArgumentParser: help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check." " / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。", ) + parser.add_argument( + "--weighted_captions", + action="store_true", + default=False, + help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意", + ) return parser From 2d5f7fa709c31d07a1bb44b5be391c29b77d3cfc Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 13 Oct 2024 19:23:21 +0900 Subject: [PATCH 245/356] update README --- README.md | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 544c665d..7fae50d1 100644 --- a/README.md +++ b/README.md @@ -11,10 +11,15 @@ The command to install PyTorch is as follows: ### Recent Updates -Oct 12, 2024 (update 1): +Oct 13, 2024: + +- Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large. - During multi-GPU training, caching of latents and Text Encoder outputs is now done in multi-GPU. -- `--text_encoder_batch_size` option is enabled for FLUX.1 LoRA training and fine tuning. This option specifies the batch size for caching Text Encoder outputs (not for training). The default is same as the dataset batch size. If you have enough VRAM, you can increase the batch size to speed up the caching. + - Please make sure that `--highvram` and `--vae_batch_size` are specified correctly. If you have enough VRAM, you can increase the batch size to speed up the caching. + - `--text_encoder_batch_size` option is enabled for FLUX.1 LoRA training and fine tuning. This option specifies the batch size for caching Text Encoder outputs (not for training). The default is same as the dataset batch size. If you have enough VRAM, you can increase the batch size to speed up the caching. + - Multi-threading is also implemented for caching of latents. This may speed up the caching process about 5% (depends on the environment). + - `tools/cache_latents.py` and `tools/cache_text_encoder_outputs.py` also have been updated to support multi-GPU caching. - `--skip_cache_check` option is added to each training script. - When specified, the consistency check of the cache file `*.npz` contents (e.g., image size and flip for latents, mask for Text Encoder outputs) is skipped. - Specify this option if you have a large number of cache files and the consistency check takes time. From 2500f5a79806fdbe74c43db24a95ee19329a8fcc Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 15 Oct 2024 07:16:34 +0900 Subject: [PATCH 246/356] fix latents caching not working closes #1696 --- fine_tune.py | 2 +- flux_train.py | 2 +- sd3_train.py | 2 +- sdxl_train.py | 2 +- sdxl_train_control_net.py | 2 +- train_db.py | 2 +- train_textual_inversion.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index cdc005d9..0b7cc510 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -177,7 +177,7 @@ def train(args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) diff --git a/flux_train.py b/flux_train.py index 46a8babd..91ae3af5 100644 --- a/flux_train.py +++ b/flux_train.py @@ -190,7 +190,7 @@ def train(args): ae.requires_grad_(False) ae.eval() - train_dataset_group.new_cache_latents(ae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(ae, accelerator) ae.to("cpu") # if no sampling, vae can be deleted clean_memory_on_device(accelerator.device) diff --git a/sd3_train.py b/sd3_train.py index 7290956a..ef18c32c 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -243,7 +243,7 @@ def train(args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") # if no sampling, vae can be deleted clean_memory_on_device(accelerator.device) diff --git a/sdxl_train.py b/sdxl_train.py index 9b2d1916..79a2fbb6 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -272,7 +272,7 @@ def train(args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 74b3a64a..24080afb 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -209,7 +209,7 @@ def train(args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) diff --git a/train_db.py b/train_db.py index 683b4233..4a58e27b 100644 --- a/train_db.py +++ b/train_db.py @@ -156,7 +156,7 @@ def train(args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 4d8a3abb..77b5d717 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -378,7 +378,7 @@ class TextualInversionTrainer: vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() From 3cc5b8db99c66b9e205c4fd4a5f969090c51ef58 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 18 Oct 2024 20:57:13 +0900 Subject: [PATCH 247/356] Diff Output Preserv loss for SDXL --- library/config_util.py | 17 +++++++---------- library/train_util.py | 17 ++++++++++++++++- sdxl_train_network.py | 20 +++++++++++++++++++- train_network.py | 35 +++++++++++++++++++++++++---------- 4 files changed, 67 insertions(+), 22 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index f8cdfe60..fc1fbf46 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -10,13 +10,7 @@ import json from pathlib import Path # from toolz import curry -from typing import ( - List, - Optional, - Sequence, - Tuple, - Union, -) +from typing import Dict, List, Optional, Sequence, Tuple, Union import toml import voluptuous @@ -78,6 +72,7 @@ class BaseSubsetParams: caption_tag_dropout_rate: float = 0.0 token_warmup_min: int = 1 token_warmup_step: float = 0 + custom_attributes: Optional[Dict[str, Any]] = None @dataclass @@ -197,6 +192,7 @@ class ConfigSanitizer: "token_warmup_step": Any(float, int), "caption_prefix": str, "caption_suffix": str, + "custom_attributes": dict, } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { @@ -538,9 +534,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu flip_aug: {subset.flip_aug} face_crop_aug_range: {subset.face_crop_aug_range} random_crop: {subset.random_crop} - token_warmup_min: {subset.token_warmup_min}, - token_warmup_step: {subset.token_warmup_step}, - alpha_mask: {subset.alpha_mask}, + token_warmup_min: {subset.token_warmup_min} + token_warmup_step: {subset.token_warmup_step} + alpha_mask: {subset.alpha_mask} + custom_attributes: {subset.custom_attributes} """ ), " ", diff --git a/library/train_util.py b/library/train_util.py index 4a446e81..7d3fce5b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -396,6 +396,7 @@ class BaseSubset: caption_suffix: Optional[str], token_warmup_min: int, token_warmup_step: Union[float, int], + custom_attributes: Optional[Dict[str, Any]] = None, ) -> None: self.image_dir = image_dir self.alpha_mask = alpha_mask if alpha_mask is not None else False @@ -419,6 +420,8 @@ class BaseSubset: self.token_warmup_min = token_warmup_min # step=0におけるタグの数 self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる + self.custom_attributes = custom_attributes if custom_attributes is not None else {} + self.img_count = 0 @@ -449,6 +452,7 @@ class DreamBoothSubset(BaseSubset): caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes: Optional[Dict[str, Any]] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -473,6 +477,7 @@ class DreamBoothSubset(BaseSubset): caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes=custom_attributes, ) self.is_reg = is_reg @@ -512,6 +517,7 @@ class FineTuningSubset(BaseSubset): caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes: Optional[Dict[str, Any]] = None, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -536,6 +542,7 @@ class FineTuningSubset(BaseSubset): caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes=custom_attributes, ) self.metadata_file = metadata_file @@ -1474,11 +1481,14 @@ class BaseDataset(torch.utils.data.Dataset): target_sizes_hw = [] flippeds = [] # 変数名が微妙 text_encoder_outputs_list = [] + custom_attributes = [] for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] + custom_attributes.append(subset.custom_attributes) + # in case of fine tuning, is_reg is always False loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) @@ -1646,7 +1656,9 @@ class BaseDataset(torch.utils.data.Dataset): return None return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))] + # set example example = {} + example["custom_attributes"] = custom_attributes # may be list of empty dict example["loss_weights"] = torch.FloatTensor(loss_weights) example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor) example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x) @@ -2630,7 +2642,9 @@ def debug_dataset(train_dataset, show_input_ids=False): f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}' ) if "network_multipliers" in example: - print(f"network multiplier: {example['network_multipliers'][j]}") + logger.info(f"network multiplier: {example['network_multipliers'][j]}") + if "custom_attributes" in example: + logger.info(f"custom attributes: {example['custom_attributes'][j]}") # if show_input_ids: # logger.info(f"input ids: {iid}") @@ -4091,6 +4105,7 @@ def enable_high_vram(args: argparse.Namespace): global HIGH_VRAM HIGH_VRAM = True + def verify_training_args(args: argparse.Namespace): r""" Verify training arguments. Also reflect highvram option to global variable diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 4a16a489..d45df6e0 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,4 +1,5 @@ import argparse +from typing import List, Optional import torch from accelerate import Accelerator @@ -172,7 +173,18 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): return encoder_hidden_states1, encoder_hidden_states2, pool2 - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + def call_unet( + self, + args, + accelerator, + unet, + noisy_latents, + timesteps, + text_conds, + batch, + weight_dtype, + indices: Optional[List[int]] = None, + ): noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype # get size embeddings @@ -186,6 +198,12 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + if indices is not None and len(indices) > 0: + noisy_latents = noisy_latents[indices] + timesteps = timesteps[indices] + text_embedding = text_embedding[indices] + vector_embedding = vector_embedding[indices] + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) return noise_pred diff --git a/train_network.py b/train_network.py index d5330aef..ef766737 100644 --- a/train_network.py +++ b/train_network.py @@ -143,7 +143,7 @@ class NetworkTrainer: for t_enc in text_encoders: t_enc.to(accelerator.device, dtype=weight_dtype) - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs): noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred @@ -218,6 +218,30 @@ class NetworkTrainer: else: target = noise + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + noise_pred_prior = self.call_unet( + args, + accelerator, + unet, + noisy_latents, + timesteps, + text_encoder_conds, + batch, + weight_dtype, + indices=diff_output_pr_indices, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) + return noise_pred, target, timesteps, huber_c, None def post_process_loss(self, loss, args, timesteps, noise_scheduler): @@ -1123,15 +1147,6 @@ class NetworkTrainer: with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: - # # SD only - # encoded_text_encoder_conds = get_weighted_text_embeddings( - # tokenizers[0], - # text_encoder, - # batch["captions"], - # accelerator.device, - # args.max_token_length // 75 if args.max_token_length else 1, - # clip_skip=args.clip_skip, - # ) input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( tokenize_strategy, From d8d7142665a8f6b2d43827c9b3a6a2de009c09cb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 18 Oct 2024 23:16:30 +0900 Subject: [PATCH 248/356] fix to work caching latents #1696 --- sdxl_train_control_net_lllite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 14ff7c24..913b1d43 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -181,7 +181,7 @@ def train(args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) From ef70aa7b42b5c923cc1a8594b2f30487a2b4f700 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Fri, 18 Oct 2024 23:39:48 +0900 Subject: [PATCH 249/356] add FLUX.1 support --- README.md | 19 +++++++ flux_train_network.py | 119 +++++++++++++++++++++++++++++------------- 2 files changed, 101 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 7fae50d1..59f70ebc 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,25 @@ The command to install PyTorch is as follows: ### Recent Updates +Oct 19, 2024: + +- Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training. + - A method to make the output of LoRA closer to the output when LoRA is not applied, with captions that do not contain trigger words. + - Define a Dataset subset for the regularization image (`is_reg = true`) with `.toml`. Add `custom_attributes.diff_output_preservation = true`. + - See [dataset configuration](docs/config_README-en.md) for the regularization dataset. + - Specify "number of training images x number of epochs >= number of regularization images x number of epochs". + - Specify a large value for `--prior_loss_weight` option (not dataset config). We recommend 10-1000. + - Set the loss in the training without using the regularization image to be close to the loss in the training using DOP. +``` +[[datasets.subsets]] +image_dir = "path/to/image/dir" +num_repeats = 1 +is_reg = true +custom_attributes.diff_output_preservation = true # Add this +``` + + + Oct 13, 2024: - Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large. diff --git a/flux_train_network.py b/flux_train_network.py index aa92fe3a..8431a6dc 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -373,33 +373,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): if not args.apply_t5_attn_mask: t5_attn_mask = None - if not args.split_mode: - # normal forward - with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - model_pred = unet( - img=packed_noisy_model_input, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) - else: - # split forward to reduce memory usage - assert network.train_blocks == "single", "train_blocks must be single for split mode" - with accelerator.autocast(): - # move flux lower to cpu, and then move flux upper to gpu - unet.to("cpu") - clean_memory_on_device(accelerator.device) - self.flux_upper.to(accelerator.device) - - # upper model does not require grad - with torch.no_grad(): - intermediate_img, intermediate_txt, vec, pe = self.flux_upper( - img=packed_noisy_model_input, + def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): + if not args.split_mode: + # normal forward + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=img, img_ids=img_ids, txt=t5_out, txt_ids=txt_ids, @@ -408,18 +388,52 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): guidance=guidance_vec, txt_attention_mask=t5_attn_mask, ) + else: + # split forward to reduce memory usage + assert network.train_blocks == "single", "train_blocks must be single for split mode" + with accelerator.autocast(): + # move flux lower to cpu, and then move flux upper to gpu + unet.to("cpu") + clean_memory_on_device(accelerator.device) + self.flux_upper.to(accelerator.device) - # move flux upper back to cpu, and then move flux lower to gpu - self.flux_upper.to("cpu") - clean_memory_on_device(accelerator.device) - unet.to(accelerator.device) + # upper model does not require grad + with torch.no_grad(): + intermediate_img, intermediate_txt, vec, pe = self.flux_upper( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) - # lower model requires grad - intermediate_img.requires_grad_(True) - intermediate_txt.requires_grad_(True) - vec.requires_grad_(True) - pe.requires_grad_(True) - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + # move flux upper back to cpu, and then move flux lower to gpu + self.flux_upper.to("cpu") + clean_memory_on_device(accelerator.device) + unet.to(accelerator.device) + + # lower model requires grad + intermediate_img.requires_grad_(True) + intermediate_txt.requires_grad_(True) + vec.requires_grad_(True) + pe.requires_grad_(True) + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + + return model_pred + + model_pred = call_dit( + img=packed_noisy_model_input, + img_ids=img_ids, + t5_out=t5_out, + txt_ids=txt_ids, + l_pooled=l_pooled, + timesteps=timesteps, + guidance_vec=guidance_vec, + t5_attn_mask=t5_attn_mask, + ) # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) @@ -430,6 +444,37 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # flow matching loss: this is different from SD3 target = noise - latents + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + model_pred_prior = call_dit( + img=packed_noisy_model_input[diff_output_pr_indices], + img_ids=img_ids[diff_output_pr_indices], + t5_out=t5_out[diff_output_pr_indices], + txt_ids=txt_ids[diff_output_pr_indices], + l_pooled=l_pooled[diff_output_pr_indices], + timesteps=timesteps[diff_output_pr_indices], + guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None, + t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + + model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width) + model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( + args, + model_pred_prior, + noisy_model_input[diff_output_pr_indices], + sigmas[diff_output_pr_indices] if sigmas is not None else None, + ) + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + return model_pred, target, timesteps, None, weighting def post_process_loss(self, loss, args, timesteps, noise_scheduler): From 2c45d979e696fd4412ae1336feaee3bc9b967af4 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sat, 19 Oct 2024 19:21:12 +0900 Subject: [PATCH 250/356] update README, remove unnecessary autocast --- README.md | 10 ++++------ flux_train_network.py | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 59f70ebc..32ee3857 100644 --- a/README.md +++ b/README.md @@ -13,13 +13,13 @@ The command to install PyTorch is as follows: Oct 19, 2024: -- Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training. +- Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training. SD1/2 is not tested yet. This is an experimental feature. - A method to make the output of LoRA closer to the output when LoRA is not applied, with captions that do not contain trigger words. - Define a Dataset subset for the regularization image (`is_reg = true`) with `.toml`. Add `custom_attributes.diff_output_preservation = true`. - See [dataset configuration](docs/config_README-en.md) for the regularization dataset. - - Specify "number of training images x number of epochs >= number of regularization images x number of epochs". - - Specify a large value for `--prior_loss_weight` option (not dataset config). We recommend 10-1000. - - Set the loss in the training without using the regularization image to be close to the loss in the training using DOP. + - Specify "number of training images x number of repeats >= number of regularization images x number of repeats". + - Specify a large value for `--prior_loss_weight` option (not dataset config). The appropriate value is unknown, but try around 10-100. Note that the default is 1.0. + - You may want to start with 2/3 to 3/4 of the loss value when DOP is not applied. If it is 1/2, DOP may not be working. ``` [[datasets.subsets]] image_dir = "path/to/image/dir" @@ -28,8 +28,6 @@ is_reg = true custom_attributes.diff_output_preservation = true # Add this ``` - - Oct 13, 2024: - Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large. diff --git a/flux_train_network.py b/flux_train_network.py index 8431a6dc..9cc8811b 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -453,7 +453,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): if len(diff_output_pr_indices) > 0: network.set_multiplier(0.0) - with torch.no_grad(), accelerator.autocast(): + with torch.no_grad(): model_pred_prior = call_dit( img=packed_noisy_model_input[diff_output_pr_indices], img_ids=img_ids[diff_output_pr_indices], From 7fe8e162cb54ccf259eead1cca0ebdcc4e2b77fe Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 20 Oct 2024 08:45:27 +0900 Subject: [PATCH 251/356] fix to work ControlNetSubset with custom_attributes --- library/train_util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 7d3fce5b..462c7a9a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -578,6 +578,7 @@ class ControlNetSubset(BaseSubset): caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes: Optional[Dict[str, Any]] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -602,6 +603,7 @@ class ControlNetSubset(BaseSubset): caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes=custom_attributes, ) self.conditioning_data_dir = conditioning_data_dir From 138dac4aea57716e2f23580305f6e40836a87228 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 20 Oct 2024 09:22:38 +0900 Subject: [PATCH 252/356] update README --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 32ee3857..532c3368 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,9 @@ Oct 19, 2024: - Define a Dataset subset for the regularization image (`is_reg = true`) with `.toml`. Add `custom_attributes.diff_output_preservation = true`. - See [dataset configuration](docs/config_README-en.md) for the regularization dataset. - Specify "number of training images x number of repeats >= number of regularization images x number of repeats". - - Specify a large value for `--prior_loss_weight` option (not dataset config). The appropriate value is unknown, but try around 10-100. Note that the default is 1.0. - - You may want to start with 2/3 to 3/4 of the loss value when DOP is not applied. If it is 1/2, DOP may not be working. + - The weights of DOP is specified by `--prior_loss_weight` option (not dataset config). + - The appropriate value is still unknown. For FLUX, according to the comments in the [PR](https://github.com/kohya-ss/sd-scripts/pull/1710), the value may be 1 (thanks to dxqbYD!). For SDXL, a larger value may be needed (10-100 may be good starting points). + - It may be good to adjust the value so that the loss is about half to three-quarters of the loss when DOP is not applied. ``` [[datasets.subsets]] image_dir = "path/to/image/dir" @@ -28,6 +29,7 @@ is_reg = true custom_attributes.diff_output_preservation = true # Add this ``` + Oct 13, 2024: - Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large. From 8fc30f820595f80ec3f09738cc4cf01f441c41b7 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Mon, 21 Oct 2024 07:34:33 -0400 Subject: [PATCH 253/356] Fix training for V-pred and ztSNR 1) Updates debiased estimation loss function for V-pred. 2) Prevents now-deprecated scaling of loss if ztSNR is enabled. --- fine_tune.py | 4 ++-- library/custom_train_functions.py | 7 +++++-- library/train_util.py | 5 +++++ sdxl_train.py | 4 ++-- sdxl_train_control_net_lllite.py | 4 ++-- sdxl_train_control_net_lllite_old.py | 4 ++-- train_db.py | 4 ++-- train_network.py | 4 ++-- train_textual_inversion.py | 4 ++-- train_textual_inversion_XTI.py | 4 ++-- 10 files changed, 26 insertions(+), 18 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index b556672d..19a35229 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -383,10 +383,10 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: + if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # mean over batch dimension else: diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 2a513dc5..faf44304 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -96,10 +96,13 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los return loss -def apply_debiased_estimation(loss, timesteps, noise_scheduler): +def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False): snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 - weight = 1 / torch.sqrt(snr_t) + if v_prediction: + weight = 1 / (snr_t + 1) + else: + weight = 1 / torch.sqrt(snr_t) loss = weight * loss return loss diff --git a/library/train_util.py b/library/train_util.py index 27910dc9..adb983d2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3731,6 +3731,11 @@ def verify_training_args(args: argparse.Namespace): raise ValueError( "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます" ) + + if args.scale_v_pred_loss_like_noise_pred and args.zero_terminal_snr: + raise ValueError( + "zero_terminal_snr enabled. scale_v_pred_loss_like_noise_pred will not be used / zero_terminal_snrが有効です。scale_v_pred_loss_like_noise_predは使用されません" + ) if args.v_pred_like_loss and args.v_parameterization: raise ValueError( diff --git a/sdxl_train.py b/sdxl_train.py index e0a8f2b2..44ee9233 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -725,12 +725,12 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: + if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # mean over batch dimension else: diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 5ff060a9..436f0e19 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -474,12 +474,12 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: + if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 292a0463..8fba9eba 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -434,12 +434,12 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: + if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_db.py b/train_db.py index 2c7f0258..d5a94a56 100644 --- a/train_db.py +++ b/train_db.py @@ -370,10 +370,10 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: + if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_network.py b/train_network.py index 044ec3aa..790fbfc9 100644 --- a/train_network.py +++ b/train_network.py @@ -993,12 +993,12 @@ class NetworkTrainer: if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: + if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 96e7bd50..10b34db5 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -598,12 +598,12 @@ class TextualInversionTrainer: if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: + if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index efb59137..084b90c6 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -483,10 +483,10 @@ def train(args): loss = loss * loss_weights if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: + if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし From e1b63c2249345e4f14c10cbb252da68157ac13b7 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Mon, 21 Oct 2024 08:12:53 -0400 Subject: [PATCH 254/356] Only add warning for deprecated scaling vpred loss function --- fine_tune.py | 2 +- library/train_util.py | 11 ++++++----- sdxl_train.py | 2 +- sdxl_train_control_net_lllite.py | 2 +- sdxl_train_control_net_lllite_old.py | 2 +- train_db.py | 2 +- train_network.py | 2 +- train_textual_inversion.py | 2 +- train_textual_inversion_XTI.py | 2 +- 9 files changed, 14 insertions(+), 13 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 19a35229..c79f97d2 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -383,7 +383,7 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: + if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) diff --git a/library/train_util.py b/library/train_util.py index adb983d2..f479dcc6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3727,15 +3727,16 @@ def verify_training_args(args: argparse.Namespace): if args.adaptive_noise_scale is not None and args.noise_offset is None: raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です") + if args.scale_v_pred_loss_like_noise_pred: + logger.warning( + f"scale_v_pred_loss_like_noise_pred is deprecated. it is suggested to use min_snr_gamma or debiased_estimation_loss" + + " / scale_v_pred_loss_like_noise_pred は非推奨です。min_snr_gammaまたはdebiased_estimation_lossを使用することをお勧めします" + ) + if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization: raise ValueError( "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます" ) - - if args.scale_v_pred_loss_like_noise_pred and args.zero_terminal_snr: - raise ValueError( - "zero_terminal_snr enabled. scale_v_pred_loss_like_noise_pred will not be used / zero_terminal_snrが有効です。scale_v_pred_loss_like_noise_predは使用されません" - ) if args.v_pred_like_loss and args.v_parameterization: raise ValueError( diff --git a/sdxl_train.py b/sdxl_train.py index 44ee9233..b533b274 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -725,7 +725,7 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: + if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 436f0e19..0e67cde5 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -474,7 +474,7 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: + if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 8fba9eba..4a01f9e2 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -434,7 +434,7 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: + if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) diff --git a/train_db.py b/train_db.py index d5a94a56..e7cf3cde 100644 --- a/train_db.py +++ b/train_db.py @@ -370,7 +370,7 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: + if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) diff --git a/train_network.py b/train_network.py index 790fbfc9..7bf125dc 100644 --- a/train_network.py +++ b/train_network.py @@ -993,7 +993,7 @@ class NetworkTrainer: if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: + if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 10b34db5..37349da7 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -598,7 +598,7 @@ class TextualInversionTrainer: if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: + if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 084b90c6..fac0787b 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -483,7 +483,7 @@ def train(args): loss = loss * loss_weights if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: + if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) From 0e7c5929336173e30d7932c0706eaf61a7d396f4 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Tue, 22 Oct 2024 11:19:34 -0400 Subject: [PATCH 255/356] Remove scale_v_pred_loss_like_noise_pred deprecation https://github.com/kohya-ss/sd-scripts/pull/1715#issuecomment-2427876376 --- library/train_util.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index f479dcc6..27910dc9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3727,12 +3727,6 @@ def verify_training_args(args: argparse.Namespace): if args.adaptive_noise_scale is not None and args.noise_offset is None: raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です") - if args.scale_v_pred_loss_like_noise_pred: - logger.warning( - f"scale_v_pred_loss_like_noise_pred is deprecated. it is suggested to use min_snr_gamma or debiased_estimation_loss" - + " / scale_v_pred_loss_like_noise_pred は非推奨です。min_snr_gammaまたはdebiased_estimation_lossを使用することをお勧めします" - ) - if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization: raise ValueError( "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます" From be14c062674973d0e4fee1eb4527e04707bb72b8 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Tue, 22 Oct 2024 12:13:51 -0400 Subject: [PATCH 256/356] Remove v-pred warnings Different model architectures, such as SDXL, can take advantage of v-pred. It doesn't make sense to include these warnings anymore. --- gen_img.py | 2 -- gen_img_diffusers.py | 2 -- library/train_util.py | 4 ---- 3 files changed, 8 deletions(-) diff --git a/gen_img.py b/gen_img.py index 59bcd5b0..9427a894 100644 --- a/gen_img.py +++ b/gen_img.py @@ -1495,8 +1495,6 @@ def main(args): highres_fix = args.highres_fix_scale is not None # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" - if args.v_parameterization and not args.v2: - logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 2c40f1a0..04db4e9b 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2216,8 +2216,6 @@ def main(args): highres_fix = args.highres_fix_scale is not None # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" - if args.v_parameterization and not args.v2: - logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") diff --git a/library/train_util.py b/library/train_util.py index 27910dc9..100ef475 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3698,10 +3698,6 @@ def verify_training_args(args: argparse.Namespace): global HIGH_VRAM HIGH_VRAM = True - if args.v_parameterization and not args.v2: - logger.warning( - "v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません" - ) if args.v2 and args.clip_skip is not None: logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") From 623017f71695bcee18f36f5a1f57514974d9350d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 24 Oct 2024 19:49:28 +0900 Subject: [PATCH 257/356] refactor SD3 CLIP to transformers etc. --- flux_train.py | 4 +- flux_train_network.py | 2 +- library/flux_train_utils.py | 3 +- library/flux_utils.py | 59 +-- library/sai_model_spec.py | 9 +- library/sd3_models.py | 1002 ++--------------------------------- library/sd3_train_utils.py | 246 +++++---- library/sd3_utils.py | 503 +++++------------- library/strategy_sd3.py | 184 ++++--- library/train_util.py | 31 ++ library/utils.py | 42 +- sd3_minimal_inference.py | 450 ++++++++-------- sd3_train.py | 816 +++++++++++++++------------- 13 files changed, 1201 insertions(+), 2150 deletions(-) diff --git a/flux_train.py b/flux_train.py index 91ae3af5..79c44d7b 100644 --- a/flux_train.py +++ b/flux_train.py @@ -29,7 +29,7 @@ init_ipex() from accelerate.utils import set_seed from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux -from library.sd3_train_utils import load_prompts, FlowMatchEulerDiscreteScheduler +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler import library.train_util as train_util @@ -241,7 +241,7 @@ def train(args): text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - prompts = load_prompts(args.sample_prompts) + prompts = train_util.load_prompts(args.sample_prompts) sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in prompts: diff --git a/flux_train_network.py b/flux_train_network.py index 9cc8811b..cffeb3b1 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -231,7 +231,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - prompts = sd3_train_utils.load_prompts(args.sample_prompts) + prompts = train_util.load_prompts(args.sample_prompts) sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in prompts: diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index b3c9184f..fa673a2f 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -15,7 +15,6 @@ from PIL import Image from safetensors.torch import save_file from library import flux_models, flux_utils, strategy_base, train_util -from library.sd3_train_utils import load_prompts from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -70,7 +69,7 @@ def sample_images( text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) - prompts = load_prompts(args.sample_prompts) + prompts = train_util.load_prompts(args.sample_prompts) save_dir = args.output_dir + "/sample" os.makedirs(save_dir, exist_ok=True) diff --git a/library/flux_utils.py b/library/flux_utils.py index 7a1ec37b..86a2ec60 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -10,40 +10,21 @@ from safetensors import safe_open from accelerate import init_empty_weights from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config -from library import flux_models - -from library.utils import setup_logging, MemoryEfficientSafeOpen +from library.utils import setup_logging setup_logging() import logging logger = logging.getLogger(__name__) +from library import flux_models +from library.utils import load_safetensors + MODEL_VERSION_FLUX_V1 = "flux1" MODEL_NAME_DEV = "dev" MODEL_NAME_SCHNELL = "schnell" -# temporary copy from sd3_utils TODO refactor -def load_safetensors( - path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32 -): - if disable_mmap: - # return safetensors.torch.load(open(path, "rb").read()) - # use experimental loader - logger.info(f"Loading without mmap (experimental)") - state_dict = {} - with MemoryEfficientSafeOpen(path) as f: - for key in f.keys(): - state_dict[key] = f.get_tensor(key).to(device, dtype=dtype) - return state_dict - else: - try: - return load_file(path, device=device) - except: - return load_file(path) # prevent device invalid Error - - def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: """ チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。 @@ -161,8 +142,14 @@ def load_ae( return ae -def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> CLIPTextModel: - logger.info("Building CLIP") +def load_clip_l( + ckpt_path: Optional[str], + dtype: torch.dtype, + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> CLIPTextModel: + logger.info("Building CLIP-L") CLIPL_CONFIG = { "_name_or_path": "clip-vit-large-patch14/", "architectures": ["CLIPModel"], @@ -255,15 +242,22 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev with init_empty_weights(): clip = CLIPTextModel._from_config(config) - logger.info(f"Loading state dict from {ckpt_path}") - sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = clip.load_state_dict(sd, strict=False, assign=True) - logger.info(f"Loaded CLIP: {info}") + logger.info(f"Loaded CLIP-L: {info}") return clip def load_t5xxl( - ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False + ckpt_path: str, + dtype: Optional[torch.dtype], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, ) -> T5EncoderModel: T5_CONFIG_JSON = """ { @@ -303,8 +297,11 @@ def load_t5xxl( with init_empty_weights(): t5xxl = T5EncoderModel._from_config(config) - logger.info(f"Loading state dict from {ckpt_path}") - sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = t5xxl.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded T5xxl: {info}") return t5xxl diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index ad72ec00..8896c047 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -57,8 +57,8 @@ ARCH_SD_V1 = "stable-diffusion-v1" ARCH_SD_V2_512 = "stable-diffusion-v2-512" ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v" ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" -ARCH_SD3_M = "stable-diffusion-3-medium" -ARCH_SD3_UNKNOWN = "stable-diffusion-3" +ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc. +# ARCH_SD3_UNKNOWN = "stable-diffusion-3" ARCH_FLUX_1_DEV = "flux-1-dev" ARCH_FLUX_1_UNKNOWN = "flux-1" @@ -140,10 +140,7 @@ def build_metadata( if sdxl: arch = ARCH_SD_XL_V1_BASE elif sd3 is not None: - if sd3 == "m": - arch = ARCH_SD3_M - else: - arch = ARCH_SD3_UNKNOWN + arch = ARCH_SD3_M + "-" + sd3 elif flux is not None: if flux == "dev": arch = ARCH_FLUX_1_DEV diff --git a/library/sd3_models.py b/library/sd3_models.py index ec704dcb..c81aa479 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -4,6 +4,7 @@ # and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution! from ast import Tuple +from dataclasses import dataclass from functools import partial import math from types import SimpleNamespace @@ -15,6 +16,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from transformers import CLIPTokenizer, T5TokenizerFast + from .utils import setup_logging setup_logging() @@ -35,141 +37,23 @@ except: memory_efficient_attention = None -# region tokenizer -class SDTokenizer: - def __init__( - self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None - ): - """ - サブクラスで各種の設定を行ってる。このクラスはその設定に基づき重み付きのトークン化を行うようだ。 - Some settings are done in subclasses. This class seems to perform tokenization with weights based on those settings. - """ - self.tokenizer: CLIPTokenizer = tokenizer - self.max_length = max_length - self.min_length = min_length - empty = self.tokenizer("")["input_ids"] - if has_start_token: - self.tokens_start = 1 - self.start_token = empty[0] - self.end_token = empty[1] - else: - self.tokens_start = 0 - self.start_token = None - self.end_token = empty[0] - self.pad_with_end = pad_with_end - self.pad_to_max_length = pad_to_max_length - vocab = self.tokenizer.get_vocab() - self.inv_vocab = {v: k for k, v in vocab.items()} - self.max_word_length = 8 - - def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: - """ - Tokenize the text without weights. - """ - if type(text) == str: - text = [text] - batch_tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt") - # return tokens["input_ids"] - - pad_token = self.end_token if self.pad_with_end else 0 - for tokens in batch_tokens["input_ids"]: - assert tokens[0] == self.start_token, f"tokens[0]: {tokens[0]}, start_token: {self.start_token}" - - def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None): - """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. - The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" - """ - ja: テキストをトークン化し、重み値を持ちます - すべての値に1.0を仮定し、他の機能を無視します。 - 詳細は参考実装には関係なく、重み自体はSD3に対して弱い影響しかありません。へぇ~ - """ - if self.pad_with_end: - pad_token = self.end_token - else: - pad_token = 0 - batch = [] - if self.start_token is not None: - batch.append((self.start_token, 1.0)) - to_tokenize = text.replace("\n", " ").split(" ") - to_tokenize = [x for x in to_tokenize if x != ""] - for word in to_tokenize: - batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]]) - batch.append((self.end_token, 1.0)) - print(len(batch), self.max_length, self.min_length) - if self.pad_to_max_length: - batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) - if self.min_length is not None and len(batch) < self.min_length: - batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) - - # truncate to max_length - print( - f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}" - ) - if truncate_to_max_length and len(batch) > self.max_length: - batch = batch[: self.max_length] - if truncate_length is not None and len(batch) > truncate_length: - batch = batch[:truncate_length] - - return [batch] - - -class T5XXLTokenizer(SDTokenizer): - """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" - - def __init__(self): - super().__init__( - pad_with_end=False, - tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), - has_start_token=False, - pad_to_max_length=False, - max_length=99999999, - min_length=77, - ) - - -class SDXLClipGTokenizer(SDTokenizer): - def __init__(self, tokenizer): - super().__init__(pad_with_end=False, tokenizer=tokenizer) - - -class SD3Tokenizer: - def __init__(self, t5xxl=True, t5xxl_max_length: Optional[int] = 256): - if t5xxl_max_length is None: - t5xxl_max_length = 256 - - # TODO cache tokenizer settings locally or hold them in the repo like ComfyUI - clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) - self.clip_g = SDXLClipGTokenizer(clip_tokenizer) - # self.clip_l = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - # self.clip_g = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") - self.t5xxl = T5XXLTokenizer() if t5xxl else None - # t5xxl has 99999999 max length, clip has 77 - self.t5xxl_max_length = t5xxl_max_length - - def tokenize_with_weights(self, text: str): - return ( - self.clip_l.tokenize_with_weights(text), - self.clip_g.tokenize_with_weights(text), - ( - self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.t5xxl_max_length) - if self.t5xxl is not None - else None - ), - ) - - def tokenize(self, text: str): - return ( - self.clip_l.tokenize(text), - self.clip_g.tokenize(text), - (self.t5xxl.tokenize(text) if self.t5xxl is not None else None), - ) - - -# endregion - # region mmdit +@dataclass +class SD3Params: + patch_size: int + depth: int + num_patches: int + pos_embed_max_size: int + adm_in_channels: int + qk_norm: Optional[str] + x_block_self_attn_layers: List[int] + context_embedder_in_features: int + context_embedder_out_features: int + model_type: str + + def get_2d_sincos_pos_embed( embed_dim, grid_size, @@ -286,10 +170,6 @@ def timestep_embedding(t, dim, max_period=10000): return embedding -def rmsnorm(x, eps=1e-6): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) - - class PatchEmbed(nn.Module): def __init__( self, @@ -301,8 +181,9 @@ class PatchEmbed(nn.Module): flatten=True, bias=True, strict_img_size=True, - dynamic_img_pad=True, + dynamic_img_pad=False, ): + # dynamic_img_pad and norm is omitted in SD3.5 super().__init__() self.patch_size = patch_size self.flatten = flatten @@ -432,6 +313,10 @@ class Embedder(nn.Module): return self.mlp(x) +def rmsnorm(x, eps=1e-6): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + class RMSNorm(torch.nn.Module): def __init__( self, @@ -604,53 +489,6 @@ def attention(q, k, v, head_dim, mask=None, scale=None, mode="xformers"): return scores -class SelfAttention(AttentionLinears): - def __init__(self, dim, num_heads=8, mode="xformers"): - super().__init__(dim, num_heads, qkv_bias=True, pre_only=False) - assert mode in MEMORY_LAYOUTS - self.head_dim = dim // num_heads - self.attn_mode = mode - - def set_attn_mode(self, mode): - self.attn_mode = mode - - def forward(self, x): - q, k, v = self.pre_attention(x) - attn_score = attention(q, k, v, self.head_dim, mode=self.attn_mode) - return self.post_attention(attn_score) - - -class TransformerBlock(nn.Module): - def __init__(self, context_size, mode="xformers"): - super().__init__() - self.context_size = context_size - self.norm1 = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) - self.attn = SelfAttention(context_size, mode=mode) - self.norm2 = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) - self.mlp = MLP( - in_features=context_size, - hidden_features=context_size * 4, - act_layer=lambda: nn.GELU(approximate="tanh"), - ) - - def forward(self, x): - x = x + self.attn(self.norm1(x)) - x = x + self.mlp(self.norm2(x)) - return x - - -class Transformer(nn.Module): - def __init__(self, context_size, num_layers, mode="xformers"): - super().__init__() - self.layers = nn.ModuleList([TransformerBlock(context_size, mode) for _ in range(num_layers)]) - self.norm = nn.LayerNorm(context_size, elementwise_affine=False, eps=1e-6) - - def forward(self, x): - for layer in self.layers: - x = layer(x) - return self.norm(x) - - # DismantledBlock in mmdit.py class SingleDiTBlock(nn.Module): """ @@ -823,7 +661,8 @@ class MMDiT(nn.Module): mlp_ratio: float = 4.0, learn_sigma: bool = False, adm_in_channels: Optional[int] = None, - context_embedder_config: Optional[Dict] = None, + context_embedder_in_features: Optional[int] = None, + context_embedder_out_features: Optional[int] = None, use_checkpoint: bool = False, register_length: int = 0, attn_mode: str = "torch", @@ -837,10 +676,10 @@ class MMDiT(nn.Module): num_patches=None, qk_norm: Optional[str] = None, qkv_bias: bool = True, - context_processor_layers=None, - context_size=4096, + model_type: str = "sd3m", ): super().__init__() + self._model_type = model_type self.learn_sigma = learn_sigma self.in_channels = in_channels default_out_channels = in_channels * 2 if learn_sigma else in_channels @@ -875,12 +714,11 @@ class MMDiT(nn.Module): assert isinstance(adm_in_channels, int) self.y_embedder = Embedder(adm_in_channels, self.hidden_size) - if context_processor_layers is not None: - self.context_processor = Transformer(context_size, context_processor_layers, attn_mode) + if context_embedder_in_features is not None: + self.context_embedder = nn.Linear(context_embedder_in_features, context_embedder_out_features) else: - self.context_processor = None + self.context_embedder = nn.Identity() - self.context_embedder = nn.Linear(context_size, self.hidden_size) self.register_length = register_length if self.register_length > 0: self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size)) @@ -922,7 +760,7 @@ class MMDiT(nn.Module): @property def model_type(self): - return "m" # only support medium + return self._model_type @property def device(self): @@ -1024,9 +862,6 @@ class MMDiT(nn.Module): y: (N, D) tensor of class labels """ - if self.context_processor is not None: - context = self.context_processor(context) - B, C, H, W = x.shape x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device).to(dtype=x.dtype) c = self.t_embedder(t, dtype=x.dtype) # (N, D) @@ -1052,22 +887,21 @@ class MMDiT(nn.Module): return x[:, :, :H, :W] -def create_mmdit_sd3_medium_configs(attn_mode: str): - # {'patch_size': 2, 'depth': 24, 'num_patches': 36864, - # 'pos_embed_max_size': 192, 'adm_in_channels': 2048, 'context_embedder': - # {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}}} +def create_sd3_mmdit(params: SD3Params, attn_mode: str = "torch") -> MMDiT: mmdit = MMDiT( input_size=None, - pos_embed_max_size=192, - patch_size=2, + pos_embed_max_size=params.pos_embed_max_size, + patch_size=params.patch_size, in_channels=16, - adm_in_channels=2048, - depth=24, + adm_in_channels=params.adm_in_channels, + context_embedder_in_features=params.context_embedder_in_features, + context_embedder_out_features=params.context_embedder_out_features, + depth=params.depth, mlp_ratio=4, - qk_norm=None, - num_patches=36864, - context_size=4096, + qk_norm=params.qk_norm, + num_patches=params.num_patches, attn_mode=attn_mode, + model_type=params.model_type, ) return mmdit @@ -1075,7 +909,6 @@ def create_mmdit_sd3_medium_configs(attn_mode: str): # endregion # region VAE -# TODO support xformers VAE_SCALE_FACTOR = 1.5305 VAE_SHIFT_FACTOR = 0.0609 @@ -1322,759 +1155,4 @@ class SDVAE(torch.nn.Module): return (latent / VAE_SCALE_FACTOR) + VAE_SHIFT_FACTOR -class VAEOutput: - def __init__(self, latent): - self.latent = latent - - @property - def latent_dist(self): - return self - - def sample(self): - return self.latent - - -class VAEWrapper: - def __init__(self, vae): - self.vae = vae - - @property - def device(self): - return self.vae.device - - @property - def dtype(self): - return self.vae.dtype - - # latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") - def encode(self, image): - return VAEOutput(self.vae.encode(image)) - - -# endregion - - -# region Text Encoder -class CLIPAttention(torch.nn.Module): - def __init__(self, embed_dim, heads, dtype, device, mode="xformers"): - super().__init__() - self.heads = heads - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) - self.attn_mode = mode - - def set_attn_mode(self, mode): - self.attn_mode = mode - - def forward(self, x, mask=None): - q = self.q_proj(x) - k = self.k_proj(x) - v = self.v_proj(x) - out = attention(q, k, v, self.heads, mask, mode=self.attn_mode) - return self.out_proj(out) - - -ACTIVATIONS = { - "quick_gelu": lambda: (lambda a: a * torch.sigmoid(1.702 * a)), - # "gelu": torch.nn.functional.gelu, - "gelu": lambda: nn.GELU(), -} - - -class CLIPLayer(torch.nn.Module): - def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): - super().__init__() - self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) - self.self_attn = CLIPAttention(embed_dim, heads, dtype, device) - self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) - # # self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device) - # self.mlp = Mlp( - # embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device - # ) - self.mlp = MLP(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation]) - self.mlp.to(device=device, dtype=dtype) - - def forward(self, x, mask=None): - x += self.self_attn(self.layer_norm1(x), mask) - x += self.mlp(self.layer_norm2(x)) - return x - - -class CLIPEncoder(torch.nn.Module): - def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device): - super().__init__() - self.layers = torch.nn.ModuleList( - [CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)] - ) - - def forward(self, x, mask=None, intermediate_output=None): - if intermediate_output is not None: - if intermediate_output < 0: - intermediate_output = len(self.layers) + intermediate_output - intermediate = None - for i, l in enumerate(self.layers): - x = l(x, mask) - if i == intermediate_output: - intermediate = x.clone() - return x, intermediate - - -class CLIPEmbeddings(torch.nn.Module): - def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None): - super().__init__() - self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device) - self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) - - def forward(self, input_tokens): - return self.token_embedding(input_tokens) + self.position_embedding.weight - - -class CLIPTextModel_(torch.nn.Module): - def __init__(self, config_dict, dtype, device): - num_layers = config_dict["num_hidden_layers"] - embed_dim = config_dict["hidden_size"] - heads = config_dict["num_attention_heads"] - intermediate_size = config_dict["intermediate_size"] - intermediate_activation = config_dict["hidden_act"] - super().__init__() - self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) - self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) - self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) - - def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True): - x = self.embeddings(input_tokens) - - if x.dtype == torch.bfloat16: - causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=torch.float32, device=x.device).fill_(float("-inf")).triu_(1) - causal_mask = causal_mask.to(dtype=x.dtype) - else: - causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) - - x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output) - x = self.final_layer_norm(x) - if i is not None and final_layer_norm_intermediate: - i = self.final_layer_norm(i) - pooled_output = x[ - torch.arange(x.shape[0], device=x.device), - input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1), - ] - return x, i, pooled_output - - -class CLIPTextModel(torch.nn.Module): - def __init__(self, config_dict, dtype, device): - super().__init__() - self.num_layers = config_dict["num_hidden_layers"] - self.text_model = CLIPTextModel_(config_dict, dtype, device) - embed_dim = config_dict["hidden_size"] - self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device) - self.text_projection.weight.copy_(torch.eye(embed_dim)) - self.dtype = dtype - - def get_input_embeddings(self): - return self.text_model.embeddings.token_embedding - - def set_input_embeddings(self, embeddings): - self.text_model.embeddings.token_embedding = embeddings - - def forward(self, *args, **kwargs): - x = self.text_model(*args, **kwargs) - out = self.text_projection(x[2]) - return (x[0], x[1], out, x[2]) - - -class ClipTokenWeightEncoder: - # def encode_token_weights(self, token_weight_pairs): - # tokens = list(map(lambda a: a[0], token_weight_pairs[0])) - # out, pooled = self([tokens]) - # if pooled is not None: - # first_pooled = pooled[0:1] - # else: - # first_pooled = pooled - # output = [out[0:1]] - # return torch.cat(output, dim=-2), first_pooled - - # fix to support batched inputs - # : Union[List[Tuple[torch.Tensor, torch.Tensor]], List[List[Tuple[torch.Tensor, torch.Tensor]]]] - def encode_token_weights(self, list_of_token_weight_pairs): - has_batch = isinstance(list_of_token_weight_pairs[0][0], list) - - if has_batch: - list_of_tokens = [] - for pairs in list_of_token_weight_pairs: - tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0] - list_of_tokens.append(tokens) - else: - if isinstance(list_of_token_weight_pairs[0], torch.Tensor): - list_of_tokens = [list(list_of_token_weight_pairs[0])] - else: - list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] - - out, pooled = self(list_of_tokens) - if has_batch: - return out, pooled - else: - if pooled is not None: - first_pooled = pooled[0:1] - else: - first_pooled = pooled - output = [out[0:1]] - return torch.cat(output, dim=-2), first_pooled - - -class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): - """Uses the CLIP transformer encoder for text (from huggingface)""" - - LAYERS = ["last", "pooled", "hidden"] - - def __init__( - self, - device="cpu", - max_length=77, - layer="last", - layer_idx=None, - textmodel_json_config=None, - dtype=None, - model_class=CLIPTextModel, - special_tokens={"start": 49406, "end": 49407, "pad": 49407}, - layer_norm_hidden_state=True, - return_projected_pooled=True, - ): - super().__init__() - assert layer in self.LAYERS - self.transformer = model_class(textmodel_json_config, dtype, device) - self.num_layers = self.transformer.num_layers - self.max_length = max_length - self.transformer = self.transformer.eval() - for param in self.parameters(): - param.requires_grad = False - self.layer = layer - self.layer_idx = None - self.special_tokens = special_tokens - self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) - self.layer_norm_hidden_state = layer_norm_hidden_state - self.return_projected_pooled = return_projected_pooled - if layer == "hidden": - assert layer_idx is not None - assert abs(layer_idx) < self.num_layers - self.set_clip_options({"layer": layer_idx}) - self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled) - - @property - def device(self): - return next(self.parameters()).device - - @property - def dtype(self): - return next(self.parameters()).dtype - - def gradient_checkpointing_enable(self): - logger.warning("Gradient checkpointing is not supported for this model") - - def set_attn_mode(self, mode): - raise NotImplementedError("This model does not support setting the attention mode") - - def set_clip_options(self, options): - layer_idx = options.get("layer", self.layer_idx) - self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) - if layer_idx is None or abs(layer_idx) > self.num_layers: - self.layer = "last" - else: - self.layer = "hidden" - self.layer_idx = layer_idx - - def forward(self, tokens): - backup_embeds = self.transformer.get_input_embeddings() - device = backup_embeds.weight.device - tokens = torch.LongTensor(tokens).to(device) - outputs = self.transformer( - tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state - ) - self.transformer.set_input_embeddings(backup_embeds) - if self.layer == "last": - z = outputs[0] - else: - z = outputs[1] - pooled_output = None - if len(outputs) >= 3: - if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None: - pooled_output = outputs[3].float() - elif outputs[2] is not None: - pooled_output = outputs[2].float() - return z.float(), pooled_output - - def set_attn_mode(self, mode): - clip_text_model = self.transformer.text_model - for layer in clip_text_model.encoder.layers: - layer.self_attn.set_attn_mode(mode) - - -class SDXLClipG(SDClipModel): - """Wraps the CLIP-G model into the SD-CLIP-Model interface""" - - def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None): - if layer == "penultimate": - layer = "hidden" - layer_idx = -2 - super().__init__( - device=device, - layer=layer, - layer_idx=layer_idx, - textmodel_json_config=config, - dtype=dtype, - special_tokens={"start": 49406, "end": 49407, "pad": 0}, - layer_norm_hidden_state=False, - ) - - def set_attn_mode(self, mode): - clip_text_model = self.transformer.text_model - for layer in clip_text_model.encoder.layers: - layer.self_attn.set_attn_mode(mode) - - -class T5XXLModel(SDClipModel): - """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience""" - - def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None): - super().__init__( - device=device, - layer=layer, - layer_idx=layer_idx, - textmodel_json_config=config, - dtype=dtype, - special_tokens={"end": 1, "pad": 0}, - model_class=T5, - ) - - def set_attn_mode(self, mode): - t5: T5 = self.transformer - for t5block in t5.encoder.block: - t5block: T5Block - t5layer: T5LayerSelfAttention = t5block.layer[0] - t5SaSa: T5Attention = t5layer.SelfAttention - t5SaSa.set_attn_mode(mode) - - -################################################################################################# -### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl -################################################################################################# - -""" -class T5XXLTokenizer(SDTokenizer): - ""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"" - - def __init__(self): - super().__init__( - pad_with_end=False, - tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), - has_start_token=False, - pad_to_max_length=False, - max_length=99999999, - min_length=77, - ) -""" - - -class T5LayerNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None): - super().__init__() - self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device)) - self.variance_epsilon = eps - - # def forward(self, x): - # variance = x.pow(2).mean(-1, keepdim=True) - # x = x * torch.rsqrt(variance + self.variance_epsilon) - # return self.weight.to(device=x.device, dtype=x.dtype) * x - - # copy from transformers' T5LayerNorm - def forward(self, hidden_states): - # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated - # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for - # half-precision inputs is done in fp32 - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states - - -class T5DenseGatedActDense(torch.nn.Module): - def __init__(self, model_dim, ff_dim, dtype, device): - super().__init__() - self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) - self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) - self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device) - - def forward(self, x): - hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh") - hidden_linear = self.wi_1(x) - x = hidden_gelu * hidden_linear - x = self.wo(x) - return x - - -class T5LayerFF(torch.nn.Module): - def __init__(self, model_dim, ff_dim, dtype, device): - super().__init__() - self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device) - self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) - - def forward(self, x): - forwarded_states = self.layer_norm(x) - forwarded_states = self.DenseReluDense(forwarded_states) - x += forwarded_states - return x - - -class T5Attention(torch.nn.Module): - def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device): - super().__init__() - # Mesh TensorFlow initialization to avoid scaling before softmax - self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device) - self.num_heads = num_heads - self.relative_attention_bias = None - if relative_attention_bias: - self.relative_attention_num_buckets = 32 - self.relative_attention_max_distance = 128 - self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device) - - self.attn_mode = "xformers" # TODO 何とかする - - def set_attn_mode(self, mode): - self.attn_mode = mode - - @staticmethod - def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. - This should allow for more graceful generalization to longer sequences than the model has been trained on - - Args: - relative_position: an int32 Tensor - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer - - Returns: - a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) - """ - relative_buckets = 0 - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0).to(torch.long) * num_buckets - relative_position = torch.abs(relative_position) - else: - relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) - # now relative_position is in the range [0, inf) - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) - ).to(torch.long) - relative_position_if_large = torch.min( - relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) - ) - relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) - return relative_buckets - - def compute_bias(self, query_length, key_length, device): - """Compute binned relative position bias""" - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] - memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] - relative_position = memory_position - context_position # shape (query_length, key_length) - relative_position_bucket = self._relative_position_bucket( - relative_position, # shape (query_length, key_length) - bidirectional=True, - num_buckets=self.relative_attention_num_buckets, - max_distance=self.relative_attention_max_distance, - ) - values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) - values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) - return values - - def forward(self, x, past_bias=None): - q = self.q(x) - k = self.k(x) - v = self.v(x) - if self.relative_attention_bias is not None: - past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device) - if past_bias is not None: - mask = past_bias - out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask, mode=self.attn_mode) - return self.o(out), past_bias - - -class T5LayerSelfAttention(torch.nn.Module): - def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): - super().__init__() - self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device) - self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) - - def forward(self, x, past_bias=None): - output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias) - x += output - return x, past_bias - - -class T5Block(torch.nn.Module): - def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device): - super().__init__() - self.layer = torch.nn.ModuleList() - self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device)) - self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device)) - - def forward(self, x, past_bias=None): - x, past_bias = self.layer[0](x, past_bias) - - # copy from transformers' T5Block - # clamp inf values to enable fp16 training - if x.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(x).any(), - torch.finfo(x.dtype).max - 1000, - torch.finfo(x.dtype).max, - ) - x = torch.clamp(x, min=-clamp_value, max=clamp_value) - - x = self.layer[-1](x) - # clamp inf values to enable fp16 training - if x.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(x).any(), - torch.finfo(x.dtype).max - 1000, - torch.finfo(x.dtype).max, - ) - x = torch.clamp(x, min=-clamp_value, max=clamp_value) - - return x, past_bias - - -class T5Stack(torch.nn.Module): - def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device): - super().__init__() - self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device) - self.block = torch.nn.ModuleList( - [ - T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) - for i in range(num_layers) - ] - ) - self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) - - def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True): - intermediate = None - x = self.embed_tokens(input_ids) - past_bias = None - for i, l in enumerate(self.block): - # uncomment to debug layerwise output: fp16 may cause issues - # print(i, x.mean(), x.std()) - x, past_bias = l(x, past_bias) - if i == intermediate_output: - intermediate = x.clone() - # print(x.mean(), x.std()) - x = self.final_layer_norm(x) - if intermediate is not None and final_layer_norm_intermediate: - intermediate = self.final_layer_norm(intermediate) - # print(x.mean(), x.std()) - return x, intermediate - - -class T5(torch.nn.Module): - def __init__(self, config_dict, dtype, device): - super().__init__() - self.num_layers = config_dict["num_layers"] - self.encoder = T5Stack( - self.num_layers, - config_dict["d_model"], - config_dict["d_model"], - config_dict["d_ff"], - config_dict["num_heads"], - config_dict["vocab_size"], - dtype, - device, - ) - self.dtype = dtype - - def get_input_embeddings(self): - return self.encoder.embed_tokens - - def set_input_embeddings(self, embeddings): - self.encoder.embed_tokens = embeddings - - def forward(self, *args, **kwargs): - return self.encoder(*args, **kwargs) - - -def create_clip_l(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None): - r""" - state_dict is not loaded, but updated with missing keys - """ - CLIPL_CONFIG = { - "hidden_act": "quick_gelu", - "hidden_size": 768, - "intermediate_size": 3072, - "num_attention_heads": 12, - "num_hidden_layers": 12, - } - with torch.no_grad(): - clip_l = SDClipModel( - layer="hidden", - layer_idx=-2, - device=device, - dtype=dtype, - layer_norm_hidden_state=False, - return_projected_pooled=False, - textmodel_json_config=CLIPL_CONFIG, - ) - clip_l.gradient_checkpointing_enable() - if state_dict is not None: - # update state_dict if provided to include logit_scale and text_projection.weight avoid errors - if "logit_scale" not in state_dict: - state_dict["logit_scale"] = clip_l.logit_scale - if "transformer.text_projection.weight" not in state_dict: - state_dict["transformer.text_projection.weight"] = clip_l.transformer.text_projection.weight - return clip_l - - -def create_clip_g(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None): - r""" - state_dict is not loaded, but updated with missing keys - """ - CLIPG_CONFIG = { - "hidden_act": "gelu", - "hidden_size": 1280, - "intermediate_size": 5120, - "num_attention_heads": 20, - "num_hidden_layers": 32, - } - with torch.no_grad(): - clip_g = SDXLClipG(CLIPG_CONFIG, device=device, dtype=dtype) - if state_dict is not None: - if "logit_scale" not in state_dict: - state_dict["logit_scale"] = clip_g.logit_scale - return clip_g - - -def create_t5xxl(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> T5XXLModel: - T5_CONFIG = {"d_ff": 10240, "d_model": 4096, "num_heads": 64, "num_layers": 24, "vocab_size": 32128} - with torch.no_grad(): - t5 = T5XXLModel(T5_CONFIG, dtype=dtype, device=device) - if state_dict is not None: - if "logit_scale" not in state_dict: - state_dict["logit_scale"] = t5.logit_scale - if "transformer.shared.weight" in state_dict: - state_dict.pop("transformer.shared.weight") - return t5 - - -""" - # snippet for using the T5 model from transformers - - from transformers import T5EncoderModel, T5Config - import accelerate - import json - - T5_CONFIG_JSON = "" -{ - "architectures": [ - "T5EncoderModel" - ], - "classifier_dropout": 0.0, - "d_ff": 10240, - "d_kv": 64, - "d_model": 4096, - "decoder_start_token_id": 0, - "dense_act_fn": "gelu_new", - "dropout_rate": 0.1, - "eos_token_id": 1, - "feed_forward_proj": "gated-gelu", - "initializer_factor": 1.0, - "is_encoder_decoder": true, - "is_gated_act": true, - "layer_norm_epsilon": 1e-06, - "model_type": "t5", - "num_decoder_layers": 24, - "num_heads": 64, - "num_layers": 24, - "output_past": true, - "pad_token_id": 0, - "relative_attention_max_distance": 128, - "relative_attention_num_buckets": 32, - "tie_word_embeddings": false, - "torch_dtype": "float16", - "transformers_version": "4.41.2", - "use_cache": true, - "vocab_size": 32128 -} -"" - config = json.loads(T5_CONFIG_JSON) - config = T5Config(**config) - - # model = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3") - # print(model.config) - # # model(**load_model.config) - - # with accelerate.init_empty_weights(): - model = T5EncoderModel._from_config(config) # , torch_dtype=dtype) - for key in list(state_dict.keys()): - if key.startswith("transformer."): - new_key = key[len("transformer.") :] - state_dict[new_key] = state_dict.pop(key) - - info = model.load_state_dict(state_dict) - print(info) - model.set_attn_mode = lambda x: None - # model.to("cpu") - - _self = model - - def enc(list_of_token_weight_pairs): - has_batch = isinstance(list_of_token_weight_pairs[0][0], list) - - if has_batch: - list_of_tokens = [] - for pairs in list_of_token_weight_pairs: - tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0] - list_of_tokens.append(tokens) - else: - list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]] - - list_of_tokens = np.array(list_of_tokens) - list_of_tokens = torch.from_numpy(list_of_tokens).to("cuda", dtype=torch.long) - out = _self(list_of_tokens) - pooled = None - if has_batch: - return out, pooled - else: - if pooled is not None: - first_pooled = pooled[0:1] - else: - first_pooled = pooled - return out[0], first_pooled - # output = [out[0:1]] - # return torch.cat(output, dim=-2), first_pooled - - model.encode_token_weights = enc - - return model -""" - # endregion diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index e819d440..9282482d 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -11,8 +11,8 @@ from safetensors.torch import save_file from accelerate import Accelerator, PartialState from tqdm import tqdm from PIL import Image +from transformers import CLIPTextModelWithProjection, T5EncoderModel -from library import sd3_models, sd3_utils, strategy_base, train_util from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -28,60 +28,16 @@ import logging logger = logging.getLogger(__name__) -from .sdxl_train_util import match_mixed_precision - - -def load_target_model( - model_type: str, - args: argparse.Namespace, - state_dict: dict, - accelerator: Accelerator, - attn_mode: str, - model_dtype: Optional[torch.dtype], - device: Optional[torch.device], -) -> Union[ - sd3_models.MMDiT, - Optional[sd3_models.SDClipModel], - Optional[sd3_models.SDXLClipG], - Optional[sd3_models.T5XXLModel], - sd3_models.SDVAE, -]: - loading_device = device if device is not None else (accelerator.device if args.lowram else "cpu") - - for pi in range(accelerator.state.num_processes): - if pi == accelerator.state.local_process_index: - logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") - - if model_type == "mmdit": - model = sd3_utils.load_mmdit(state_dict, attn_mode, model_dtype, loading_device) - elif model_type == "clip_l": - model = sd3_utils.load_clip_l(state_dict, args.clip_l, attn_mode, model_dtype, loading_device) - elif model_type == "clip_g": - model = sd3_utils.load_clip_g(state_dict, args.clip_g, attn_mode, model_dtype, loading_device) - elif model_type == "t5xxl": - model = sd3_utils.load_t5xxl(state_dict, args.t5xxl, attn_mode, model_dtype, loading_device) - elif model_type == "vae": - model = sd3_utils.load_vae(state_dict, args.vae, model_dtype, loading_device) - else: - raise ValueError(f"Unknown model type: {model_type}") - - # work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device - if args.lowram: - model = model.to(accelerator.device) - - clean_memory_on_device(accelerator.device) - accelerator.wait_for_everyone() - - return model +from library import sd3_models, sd3_utils, strategy_base, train_util def save_models( ckpt_path: str, - mmdit: sd3_models.MMDiT, - vae: sd3_models.SDVAE, - clip_l: sd3_models.SDClipModel, - clip_g: sd3_models.SDXLClipG, - t5xxl: Optional[sd3_models.T5XXLModel], + mmdit: Optional[sd3_models.MMDiT], + vae: Optional[sd3_models.SDVAE], + clip_l: Optional[CLIPTextModelWithProjection], + clip_g: Optional[CLIPTextModelWithProjection], + t5xxl: Optional[T5EncoderModel], sai_metadata: Optional[dict], save_dtype: Optional[torch.dtype] = None, ): @@ -101,24 +57,35 @@ def save_models( update_sd("model.diffusion_model.", mmdit.state_dict()) update_sd("first_stage_model.", vae.state_dict()) - if clip_l is not None: - update_sd("text_encoders.clip_l.", clip_l.state_dict()) - if clip_g is not None: - update_sd("text_encoders.clip_g.", clip_g.state_dict()) - if t5xxl is not None: - update_sd("text_encoders.t5xxl.", t5xxl.state_dict()) + # do not support unified checkpoint format for now + # if clip_l is not None: + # update_sd("text_encoders.clip_l.", clip_l.state_dict()) + # if clip_g is not None: + # update_sd("text_encoders.clip_g.", clip_g.state_dict()) + # if t5xxl is not None: + # update_sd("text_encoders.t5xxl.", t5xxl.state_dict()) save_file(state_dict, ckpt_path, metadata=sai_metadata) + if clip_l is not None: + clip_l_path = ckpt_path.replace(".safetensors", "_clip_l.safetensors") + save_file(clip_l.state_dict(), clip_l_path) + if clip_g is not None: + clip_g_path = ckpt_path.replace(".safetensors", "_clip_g.safetensors") + save_file(clip_g.state_dict(), clip_g_path) + if t5xxl is not None: + t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors") + save_file(t5xxl.state_dict(), t5xxl_path) + def save_sd3_model_on_train_end( args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, - clip_l: sd3_models.SDClipModel, - clip_g: sd3_models.SDXLClipG, - t5xxl: Optional[sd3_models.T5XXLModel], + clip_l: Optional[CLIPTextModelWithProjection], + clip_g: Optional[CLIPTextModelWithProjection], + t5xxl: Optional[T5EncoderModel], mmdit: sd3_models.MMDiT, vae: sd3_models.SDVAE, ): @@ -141,9 +108,9 @@ def save_sd3_model_on_epoch_end_or_stepwise( epoch: int, num_train_epochs: int, global_step: int, - clip_l: sd3_models.SDClipModel, - clip_g: sd3_models.SDXLClipG, - t5xxl: Optional[sd3_models.T5XXLModel], + clip_l: Optional[CLIPTextModelWithProjection], + clip_g: Optional[CLIPTextModelWithProjection], + t5xxl: Optional[T5EncoderModel], mmdit: sd3_models.MMDiT, vae: sd3_models.SDVAE, ): @@ -208,23 +175,27 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用", ) parser.add_argument( - "--save_clip", action="store_true", help="save CLIP models to checkpoint / CLIPモデルをチェックポイントに保存する" + "--save_clip", + action="store_true", + help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません", ) parser.add_argument( - "--save_t5xxl", action="store_true", help="save T5-XXL model to checkpoint / T5-XXLモデルをチェックポイントに保存する" + "--save_t5xxl", + action="store_true", + help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません", ) parser.add_argument( "--t5xxl_device", type=str, default=None, - help="T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用", + help="[DOES NOT WORK] not supported yet. T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用", ) parser.add_argument( "--t5xxl_dtype", type=str, default=None, - help="T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用", + help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用", ) # copy from Diffusers @@ -233,16 +204,25 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], + help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム", ) parser.add_argument( - "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + "--logit_mean", + type=float, + default=0.0, + help="mean to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合の平均", + ) + parser.add_argument( + "--logit_std", + type=float, + default=1.0, + help="std to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合のstd", ) - parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") parser.add_argument( "--mode_scale", type=float, default=1.29, - help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. / モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効", ) @@ -283,7 +263,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin # temporary copied from sd3_minimal_inferece.py -def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): +def get_all_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): start = sampling.timestep(sampling.sigma_max) end = sampling.timestep(sampling.sigma_min) timesteps = torch.linspace(start, end, steps) @@ -327,7 +307,7 @@ def do_sample( model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 - sigmas = get_sigmas(model_sampling, steps).to(device) + sigmas = get_all_sigmas(model_sampling, steps).to(device) noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas)) @@ -371,37 +351,6 @@ def do_sample( return x -def load_prompts(prompt_file: str) -> List[Dict]: - # read prompts - if prompt_file.endswith(".txt"): - with open(prompt_file, "r", encoding="utf-8") as f: - lines = f.readlines() - prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] - elif prompt_file.endswith(".toml"): - with open(prompt_file, "r", encoding="utf-8") as f: - data = toml.load(f) - prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] - elif prompt_file.endswith(".json"): - with open(prompt_file, "r", encoding="utf-8") as f: - prompts = json.load(f) - - # preprocess prompts - for i in range(len(prompts)): - prompt_dict = prompts[i] - if isinstance(prompt_dict, str): - from library.train_util import line_to_prompt_dict - - prompt_dict = line_to_prompt_dict(prompt_dict) - prompts[i] = prompt_dict - assert isinstance(prompt_dict, dict) - - # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. - prompt_dict["enum"] = i - prompt_dict.pop("subset", None) - - return prompts - - def sample_images( accelerator: Accelerator, args: argparse.Namespace, @@ -440,7 +389,7 @@ def sample_images( text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) - prompts = load_prompts(args.sample_prompts) + prompts = train_util.load_prompts(args.sample_prompts) save_dir = args.output_dir + "/sample" os.makedirs(save_dir, exist_ok=True) @@ -510,7 +459,7 @@ def sample_image_inference( accelerator: Accelerator, args: argparse.Namespace, mmdit: sd3_models.MMDiT, - text_encoders: List[Union[sd3_models.SDClipModel, sd3_models.SDXLClipG, sd3_models.T5XXLModel]], + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], vae: sd3_models.SDVAE, save_dir, prompt_dict, @@ -568,7 +517,7 @@ def sample_image_inference( l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(prompt) te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens]) - lg_out, t5_out, pooled = te_outputs + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = te_outputs cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) # encode negative prompts @@ -578,7 +527,7 @@ def sample_image_inference( l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(negative_prompt) neg_te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens]) - lg_out, t5_out, pooled = neg_te_outputs + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = neg_te_outputs neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) # sample image @@ -609,14 +558,9 @@ def sample_image_inference( wandb_tracker = accelerator.get_tracker("wandb") import wandb + # not to commit images to avoid inconsistency between training and logging steps - wandb_tracker.log( - {f"sample_{i}": wandb.Image( - image, - caption=prompt # positive prompt as a caption - )}, - commit=False - ) + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption # region Diffusers @@ -886,4 +830,78 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): return self.config.num_train_timesteps +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +def get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz = latents.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps[indices].to(device=device) + + # Add noise according to flow matching. + sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + return noisy_model_input, timesteps, sigmas + + # endregion diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 5849518f..9ad995d8 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -1,9 +1,12 @@ +from dataclasses import dataclass import math -from typing import Dict, Optional, Union +import re +from typing import Dict, List, Optional, Union import torch import safetensors from safetensors.torch import load_file from accelerate import init_empty_weights +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPConfig, CLIPTextConfig from .utils import setup_logging @@ -19,18 +22,61 @@ from library import sdxl_model_util # region models +# TODO remove dependency on flux_utils +from library.utils import load_safetensors +from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl -def load_safetensors(path: str, dvc: Union[str, torch.device], disable_mmap: bool = False): - if disable_mmap: - return safetensors.torch.load(open(path, "rb").read()) + +def analyze_state_dict_state(state_dict: Dict, prefix: str = ""): + logger.info(f"Analyzing state dict state...") + + # analyze configs + patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2] + depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64 + num_patches = state_dict[f"{prefix}pos_embed"].shape[1] + pos_embed_max_size = round(math.sqrt(num_patches)) + adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1] + context_shape = state_dict[f"{prefix}context_embedder.weight"].shape + qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys() else None + + # x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])) + x_block_self_attn_layers = [] + re_attn = re.compile(r".(\d+).x_block.attn2.ln_k.weight") + for key in list(state_dict.keys()): + m = re_attn.match(key) + if m: + x_block_self_attn_layers.append(int(m.group(1))) + + assert len(x_block_self_attn_layers) == 0, "x_block_self_attn_layers is not supported" + + context_embedder_in_features = context_shape[1] + context_embedder_out_features = context_shape[0] + + # only supports 3-5-large and 3-medium + if qk_norm is not None: + model_type = "3-5-large" else: - try: - return load_file(path, device=dvc) - except: - return load_file(path) # prevent device invalid Error + model_type = "3-medium" + + params = sd3_models.SD3Params( + patch_size=patch_size, + depth=depth, + num_patches=num_patches, + pos_embed_max_size=pos_embed_max_size, + adm_in_channels=adm_in_channels, + qk_norm=qk_norm, + x_block_self_attn_layers=x_block_self_attn_layers, + context_embedder_in_features=context_embedder_in_features, + context_embedder_out_features=context_embedder_out_features, + model_type=model_type, + ) + logger.info(f"Analyzed state dict state: {params}") + return params -def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device]): +def load_mmdit( + state_dict: Dict, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], attn_mode: str = "torch" +) -> sd3_models.MMDiT: mmdit_sd = {} mmdit_prefix = "model.diffusion_model." @@ -40,8 +86,9 @@ def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torc # load MMDiT logger.info("Building MMDit") + params = analyze_state_dict_state(mmdit_sd) with init_empty_weights(): - mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) + mmdit = sd3_models.create_sd3_mmdit(params, attn_mode) logger.info("Loading state dict...") info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype) @@ -50,20 +97,14 @@ def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torc def load_clip_l( - state_dict: Dict, clip_l_path: Optional[str], - attn_mode: str, - clip_dtype: Optional[Union[str, torch.dtype]], + dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], disable_mmap: bool = False, + state_dict: Optional[Dict] = None, ): clip_l_sd = None - if clip_l_path: - logger.info(f"Loading clip_l from {clip_l_path}...") - clip_l_sd = load_safetensors(clip_l_path, device, disable_mmap) - for key in list(clip_l_sd.keys()): - clip_l_sd["transformer." + key] = clip_l_sd.pop(key) - else: + if clip_l_path is None: if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: # found clip_l: remove prefix "text_encoders.clip_l." logger.info("clip_l is included in the checkpoint") @@ -72,34 +113,58 @@ def load_clip_l( for k in list(state_dict.keys()): if k.startswith(prefix): clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) + elif clip_l_path is None: + logger.info("clip_l is not included in the checkpoint and clip_l_path is not provided") + return None + + # load clip_l + logger.info("Building CLIP-L") + config = CLIPTextConfig( + vocab_size=49408, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=77, + hidden_act="quick_gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=768, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + with init_empty_weights(): + clip = CLIPTextModelWithProjection(config) if clip_l_sd is None: - clip_l = None - else: - logger.info("Building ClipL") - clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd) - logger.info("Loading state dict...") - info = clip_l.load_state_dict(clip_l_sd) - logger.info(f"Loaded ClipL: {info}") - clip_l.set_attn_mode(attn_mode) - return clip_l + logger.info(f"Loading state dict from {clip_l_path}") + clip_l_sd = load_safetensors(clip_l_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + + if "text_projection.weight" not in clip_l_sd: + logger.info("Adding text_projection.weight to clip_l_sd") + clip_l_sd["text_projection.weight"] = torch.eye(768, dtype=dtype, device=device) + + info = clip.load_state_dict(clip_l_sd, strict=False, assign=True) + logger.info(f"Loaded CLIP-L: {info}") + return clip def load_clip_g( - state_dict: Dict, clip_g_path: Optional[str], - attn_mode: str, - clip_dtype: Optional[Union[str, torch.dtype]], + dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], disable_mmap: bool = False, + state_dict: Optional[Dict] = None, ): clip_g_sd = None - if clip_g_path: - logger.info(f"Loading clip_g from {clip_g_path}...") - clip_g_sd = load_safetensors(clip_g_path, device, disable_mmap) - for key in list(clip_g_sd.keys()): - clip_g_sd["transformer." + key] = clip_g_sd.pop(key) - else: + if state_dict is not None: if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: # found clip_g: remove prefix "text_encoders.clip_g." logger.info("clip_g is included in the checkpoint") @@ -108,34 +173,53 @@ def load_clip_g( for k in list(state_dict.keys()): if k.startswith(prefix): clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) + elif clip_g_path is None: + logger.info("clip_g is not included in the checkpoint and clip_g_path is not provided") + return None + + # load clip_g + logger.info("Building CLIP-G") + config = CLIPTextConfig( + vocab_size=49408, + hidden_size=1280, + intermediate_size=5120, + num_hidden_layers=32, + num_attention_heads=20, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=1280, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + with init_empty_weights(): + clip = CLIPTextModelWithProjection(config) if clip_g_sd is None: - clip_g = None - else: - logger.info("Building ClipG") - clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd) - logger.info("Loading state dict...") - info = clip_g.load_state_dict(clip_g_sd) - logger.info(f"Loaded ClipG: {info}") - clip_g.set_attn_mode(attn_mode) - return clip_g + logger.info(f"Loading state dict from {clip_g_path}") + clip_g_sd = load_safetensors(clip_g_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = clip.load_state_dict(clip_g_sd, strict=False, assign=True) + logger.info(f"Loaded CLIP-G: {info}") + return clip def load_t5xxl( - state_dict: Dict, t5xxl_path: Optional[str], - attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], disable_mmap: bool = False, + state_dict: Optional[Dict] = None, ): t5xxl_sd = None - if t5xxl_path: - logger.info(f"Loading t5xxl from {t5xxl_path}...") - t5xxl_sd = load_safetensors(t5xxl_path, device, disable_mmap) - for key in list(t5xxl_sd.keys()): - t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) - else: + if state_dict is not None: if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: # found t5xxl: remove prefix "text_encoders.t5xxl." logger.info("t5xxl is included in the checkpoint") @@ -144,29 +228,19 @@ def load_t5xxl( for k in list(state_dict.keys()): if k.startswith(prefix): t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) + elif t5xxl_path is None: + logger.info("t5xxl is not included in the checkpoint and t5xxl_path is not provided") + return None - if t5xxl_sd is None: - t5xxl = None - else: - logger.info("Building T5XXL") - - # workaround for T5XXL model creation: create with fp16 takes too long TODO support virtual device - t5xxl = sd3_models.create_t5xxl(device, torch.float32, t5xxl_sd) - t5xxl.to(dtype=dtype) - - logger.info("Loading state dict...") - info = t5xxl.load_state_dict(t5xxl_sd) - logger.info(f"Loaded T5XXL: {info}") - t5xxl.set_attn_mode(attn_mode) - return t5xxl + return flux_utils_load_t5xxl(t5xxl_path, dtype, device, disable_mmap, state_dict=t5xxl_sd) def load_vae( - state_dict: Dict, vae_path: Optional[str], vae_dtype: Optional[Union[str, torch.dtype]], device: Optional[Union[str, torch.device]], disable_mmap: bool = False, + state_dict: Optional[Dict] = None, ): vae_sd = {} if vae_path: @@ -181,299 +255,15 @@ def load_vae( vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) logger.info("Building VAE") - vae = sd3_models.SDVAE() + vae = sd3_models.SDVAE(vae_dtype, device) logger.info("Loading state dict...") info = vae.load_state_dict(vae_sd) logger.info(f"Loaded VAE: {info}") - vae.to(device=device, dtype=vae_dtype) + vae.to(device=device, dtype=vae_dtype) # make sure it's in the right device and dtype return vae -def load_models( - ckpt_path: str, - clip_l_path: str, - clip_g_path: str, - t5xxl_path: str, - vae_path: str, - attn_mode: str, - device: Union[str, torch.device], - weight_dtype: Optional[Union[str, torch.dtype]] = None, - disable_mmap: bool = False, - clip_dtype: Optional[Union[str, torch.dtype]] = None, - t5xxl_device: Optional[Union[str, torch.device]] = None, - t5xxl_dtype: Optional[Union[str, torch.dtype]] = None, - vae_dtype: Optional[Union[str, torch.dtype]] = None, -): - """ - Load SD3 models from checkpoint files. - - Args: - ckpt_path: Path to the SD3 checkpoint file. - clip_l_path: Path to the clip_l checkpoint file. - clip_g_path: Path to the clip_g checkpoint file. - t5xxl_path: Path to the t5xxl checkpoint file. - vae_path: Path to the VAE checkpoint file. - attn_mode: Attention mode for MMDiT model. - device: Device for MMDiT model. - weight_dtype: Default dtype of weights for all models. This is weight dtype, so the model dtype may be different. - disable_mmap: Disable memory mapping when loading state dict. - clip_dtype: Dtype for Clip models, or None to use default dtype. - t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device. - t5xxl_dtype: Dtype for T5XXL model, or None to use default dtype. - vae_dtype: Dtype for VAE model, or None to use default dtype. - - Returns: - Tuple of MMDiT, ClipL, ClipG, T5XXL, and VAE models. - """ - - # In SD1/2 and SDXL, the model is created with empty weights and then loaded with state dict. - # However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict. - # Therefore, we need clip_dtype and t5xxl_dtype. - - def load_state_dict(path: str, dvc: Union[str, torch.device] = device): - if disable_mmap: - return safetensors.torch.load(open(path, "rb").read()) - else: - try: - return load_file(path, device=dvc) - except: - return load_file(path) # prevent device invalid Error - - t5xxl_device = t5xxl_device or device - clip_dtype = clip_dtype or weight_dtype or torch.float32 - t5xxl_dtype = t5xxl_dtype or weight_dtype or torch.float32 - vae_dtype = vae_dtype or weight_dtype or torch.float32 - - logger.info(f"Loading SD3 models from {ckpt_path}...") - state_dict = load_state_dict(ckpt_path) - - # load clip_l - clip_l_sd = None - if clip_l_path: - logger.info(f"Loading clip_l from {clip_l_path}...") - clip_l_sd = load_state_dict(clip_l_path) - for key in list(clip_l_sd.keys()): - clip_l_sd["transformer." + key] = clip_l_sd.pop(key) - else: - if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: - # found clip_l: remove prefix "text_encoders.clip_l." - logger.info("clip_l is included in the checkpoint") - clip_l_sd = {} - prefix = "text_encoders.clip_l." - for k in list(state_dict.keys()): - if k.startswith(prefix): - clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) - - # load clip_g - clip_g_sd = None - if clip_g_path: - logger.info(f"Loading clip_g from {clip_g_path}...") - clip_g_sd = load_state_dict(clip_g_path) - for key in list(clip_g_sd.keys()): - clip_g_sd["transformer." + key] = clip_g_sd.pop(key) - else: - if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: - # found clip_g: remove prefix "text_encoders.clip_g." - logger.info("clip_g is included in the checkpoint") - clip_g_sd = {} - prefix = "text_encoders.clip_g." - for k in list(state_dict.keys()): - if k.startswith(prefix): - clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) - - # load t5xxl - t5xxl_sd = None - if t5xxl_path: - logger.info(f"Loading t5xxl from {t5xxl_path}...") - t5xxl_sd = load_state_dict(t5xxl_path, t5xxl_device) - for key in list(t5xxl_sd.keys()): - t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) - else: - if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: - # found t5xxl: remove prefix "text_encoders.t5xxl." - logger.info("t5xxl is included in the checkpoint") - t5xxl_sd = {} - prefix = "text_encoders.t5xxl." - for k in list(state_dict.keys()): - if k.startswith(prefix): - t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) - - # MMDiT and VAE - vae_sd = {} - if vae_path: - logger.info(f"Loading VAE from {vae_path}...") - vae_sd = load_state_dict(vae_path) - else: - # remove prefix "first_stage_model." - vae_sd = {} - vae_prefix = "first_stage_model." - for k in list(state_dict.keys()): - if k.startswith(vae_prefix): - vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) - - mmdit_prefix = "model.diffusion_model." - for k in list(state_dict.keys()): - if k.startswith(mmdit_prefix): - state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k) - else: - state_dict.pop(k) # remove other keys - - # load MMDiT - logger.info("Building MMDit") - with init_empty_weights(): - mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) - - logger.info("Loading state dict...") - info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype) - logger.info(f"Loaded MMDiT: {info}") - - # load ClipG and ClipL - if clip_l_sd is None: - clip_l = None - else: - logger.info("Building ClipL") - clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd) - logger.info("Loading state dict...") - info = clip_l.load_state_dict(clip_l_sd) - logger.info(f"Loaded ClipL: {info}") - clip_l.set_attn_mode(attn_mode) - - if clip_g_sd is None: - clip_g = None - else: - logger.info("Building ClipG") - clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd) - logger.info("Loading state dict...") - info = clip_g.load_state_dict(clip_g_sd) - logger.info(f"Loaded ClipG: {info}") - clip_g.set_attn_mode(attn_mode) - - # load T5XXL - if t5xxl_sd is None: - t5xxl = None - else: - logger.info("Building T5XXL") - t5xxl = sd3_models.create_t5xxl(t5xxl_device, t5xxl_dtype, t5xxl_sd) - logger.info("Loading state dict...") - info = t5xxl.load_state_dict(t5xxl_sd) - logger.info(f"Loaded T5XXL: {info}") - t5xxl.set_attn_mode(attn_mode) - - # load VAE - logger.info("Building VAE") - vae = sd3_models.SDVAE() - logger.info("Loading state dict...") - info = vae.load_state_dict(vae_sd) - logger.info(f"Loaded VAE: {info}") - vae.to(device=device, dtype=vae_dtype) - - return mmdit, clip_l, clip_g, t5xxl, vae - - # endregion -# region utils - - -def get_cond( - prompt: str, - tokenizer: sd3_models.SD3Tokenizer, - clip_l: sd3_models.SDClipModel, - clip_g: sd3_models.SDXLClipG, - t5xxl: Optional[sd3_models.T5XXLModel] = None, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, -): - l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt) - print(t5_tokens) - return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype) - - -def get_cond_from_tokens( - l_tokens, - g_tokens, - t5_tokens, - clip_l: sd3_models.SDClipModel, - clip_g: sd3_models.SDXLClipG, - t5xxl: Optional[sd3_models.T5XXLModel] = None, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, -): - l_out, l_pooled = clip_l.encode_token_weights(l_tokens) - g_out, g_pooled = clip_g.encode_token_weights(g_tokens) - lg_out = torch.cat([l_out, g_out], dim=-1) - lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) - if device is not None: - lg_out = lg_out.to(device=device) - l_pooled = l_pooled.to(device=device) - g_pooled = g_pooled.to(device=device) - if dtype is not None: - lg_out = lg_out.to(dtype=dtype) - l_pooled = l_pooled.to(dtype=dtype) - g_pooled = g_pooled.to(dtype=dtype) - - # t5xxl may be in another device (eg. cpu) - if t5_tokens is None: - t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype) - else: - t5_out, _ = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None - if device is not None: - t5_out = t5_out.to(device=device) - if dtype is not None: - t5_out = t5_out.to(dtype=dtype) - - # return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1) - return lg_out, t5_out, torch.cat((l_pooled, g_pooled), dim=-1) - - -# used if other sd3 models is available -r""" -def get_sd3_configs(state_dict: Dict): - # Important configuration values can be quickly determined by checking shapes in the source file - # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change) - # prefix = "model.diffusion_model." - prefix = "" - - patch_size = state_dict[prefix + "x_embedder.proj.weight"].shape[2] - depth = state_dict[prefix + "x_embedder.proj.weight"].shape[0] // 64 - num_patches = state_dict[prefix + "pos_embed"].shape[1] - pos_embed_max_size = round(math.sqrt(num_patches)) - adm_in_channels = state_dict[prefix + "y_embedder.mlp.0.weight"].shape[1] - context_shape = state_dict[prefix + "context_embedder.weight"].shape - context_embedder_config = { - "target": "torch.nn.Linear", - "params": {"in_features": context_shape[1], "out_features": context_shape[0]}, - } - return { - "patch_size": patch_size, - "depth": depth, - "num_patches": num_patches, - "pos_embed_max_size": pos_embed_max_size, - "adm_in_channels": adm_in_channels, - "context_embedder": context_embedder_config, - } - - -def create_mmdit_from_sd3_checkpoint(state_dict: Dict, attn_mode: str = "xformers"): - "" - Doesn't load state dict. - "" - sd3_configs = get_sd3_configs(state_dict) - - mmdit = sd3_models.MMDiT( - input_size=None, - pos_embed_max_size=sd3_configs["pos_embed_max_size"], - patch_size=sd3_configs["patch_size"], - in_channels=16, - adm_in_channels=sd3_configs["adm_in_channels"], - depth=sd3_configs["depth"], - mlp_ratio=4, - qk_norm=None, - num_patches=sd3_configs["num_patches"], - context_size=4096, - attn_mode=attn_mode, - ) - return mmdit -""" class ModelSamplingDiscreteFlow: @@ -509,6 +299,3 @@ class ModelSamplingDiscreteFlow: # assert max_denoise is False, "max_denoise not implemented" # max_denoise is always True, I'm not sure why it's there return sigma * noise + (1.0 - sigma) * latent_image - - -# endregion diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index 9fde0208..dd08cf00 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -3,7 +3,7 @@ import glob from typing import Any, List, Optional, Tuple, Union import torch import numpy as np -from transformers import CLIPTokenizer, T5TokenizerFast +from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel from library import sd3_utils, train_util from library import sd3_models @@ -48,45 +48,79 @@ class Sd3TokenizeStrategy(TokenizeStrategy): class Sd3TextEncodingStrategy(TextEncodingStrategy): - def __init__(self) -> None: - pass + def __init__(self, apply_lg_attn_mask: Optional[bool] = None, apply_t5_attn_mask: Optional[bool] = None) -> None: + """ + Args: + apply_t5_attn_mask: Default value for apply_t5_attn_mask. + """ + self.apply_lg_attn_mask = apply_lg_attn_mask + self.apply_t5_attn_mask = apply_t5_attn_mask def encode_tokens( self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], - apply_lg_attn_mask: bool = False, - apply_t5_attn_mask: bool = False, + apply_lg_attn_mask: Optional[bool] = False, + apply_t5_attn_mask: Optional[bool] = False, ) -> List[torch.Tensor]: """ returned embeddings are not masked """ clip_l, clip_g, t5xxl = models + clip_l: CLIPTextModel + clip_g: CLIPTextModelWithProjection + t5xxl: T5EncoderModel + + if apply_lg_attn_mask is None: + apply_lg_attn_mask = self.apply_lg_attn_mask + if apply_t5_attn_mask is None: + apply_t5_attn_mask = self.apply_t5_attn_mask l_tokens, g_tokens, t5_tokens = tokens[:3] - l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] if len(tokens) > 3 else [None, None, None] + + if len(tokens) > 3: + l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] + if not apply_lg_attn_mask: + l_attn_mask = None + g_attn_mask = None + else: + l_attn_mask = l_attn_mask.to(clip_l.device) + g_attn_mask = g_attn_mask.to(clip_g.device) + if not apply_t5_attn_mask: + t5_attn_mask = None + else: + t5_attn_mask = t5_attn_mask.to(t5xxl.device) + else: + l_attn_mask = None + g_attn_mask = None + t5_attn_mask = None + if l_tokens is None: assert g_tokens is None, "g_tokens must be None if l_tokens is None" lg_out = None + lg_pooled = None else: - assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" - l_out, l_pooled = clip_l(l_tokens) - g_out, g_pooled = clip_g(g_tokens) - if apply_lg_attn_mask: - l_out = l_out * l_attn_mask.to(l_out.device).unsqueeze(-1) - g_out = g_out * g_attn_mask.to(g_out.device).unsqueeze(-1) - lg_out = torch.cat([l_out, g_out], dim=-1) + with torch.no_grad(): + assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" + prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True) + l_pooled = prompt_embeds[0] + l_out = prompt_embeds.hidden_states[-2] + + prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True) + g_pooled = prompt_embeds[0] + g_out = prompt_embeds.hidden_states[-2] + + lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None + lg_out = torch.cat([l_out, g_out], dim=-1) if t5xxl is not None and t5_tokens is not None: - t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096] - if apply_t5_attn_mask: - t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) + with torch.no_grad(): + t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True) else: t5_out = None - lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None - return [lg_out, t5_out, lg_pooled] + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] # masks are used for attention masking in transformer def concat_encodings( self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor @@ -132,39 +166,38 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): return False if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used return False - # t5xxl is optional + if "apply_lg_attn_mask" not in npz: + return False + if "t5_out" not in npz: + return False + if "t5_attn_mask" not in npz: + return False + npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"] + if npz_apply_lg_attn_mask != self.apply_lg_attn_mask: + return False + if "apply_t5_attn_mask" not in npz: + return False + npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] + if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e return True - def mask_lg_attn(self, lg_out: np.ndarray, l_attn_mask: np.ndarray, g_attn_mask: np.ndarray) -> np.ndarray: - l_out = lg_out[..., :768] - g_out = lg_out[..., 768:] # 1280 - l_out = l_out * np.expand_dims(l_attn_mask, -1) # l_out = l_out * l_attn_mask. - g_out = g_out * np.expand_dims(g_attn_mask, -1) # g_out = g_out * g_attn_mask. - return np.concatenate([l_out, g_out], axis=-1) - - def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray: - return t5_out * np.expand_dims(t5_attn_mask, -1) - def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: data = np.load(npz_path) lg_out = data["lg_out"] lg_pooled = data["lg_pooled"] - t5_out = data["t5_out"] if "t5_out" in data else None + t5_out = data["t5_out"] - if self.apply_lg_attn_mask: - l_attn_mask = data["clip_l_attn_mask"] - g_attn_mask = data["clip_g_attn_mask"] - lg_out = self.mask_lg_attn(lg_out, l_attn_mask, g_attn_mask) + l_attn_mask = data["clip_l_attn_mask"] + g_attn_mask = data["clip_g_attn_mask"] + t5_attn_mask = data["t5_attn_mask"] - if self.apply_t5_attn_mask and t5_out is not None: - t5_attn_mask = data["t5_attn_mask"] - t5_out = self.mask_t5_attn(t5_out, t5_attn_mask) - - return [lg_out, t5_out, lg_pooled] + # apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] def cache_batch_outputs( self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List @@ -174,7 +207,7 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): - lg_out, t5_out, lg_pooled = sd3_text_encoding_strategy.encode_tokens( + lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens( tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask ) @@ -182,38 +215,41 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): lg_out = lg_out.float() if lg_pooled.dtype == torch.bfloat16: lg_pooled = lg_pooled.float() - if t5_out is not None and t5_out.dtype == torch.bfloat16: + if t5_out.dtype == torch.bfloat16: t5_out = t5_out.float() lg_out = lg_out.cpu().numpy() lg_pooled = lg_pooled.cpu().numpy() - if t5_out is not None: - t5_out = t5_out.cpu().numpy() + t5_out = t5_out.cpu().numpy() + + l_attn_mask = tokens_and_masks[3].cpu().numpy() + g_attn_mask = tokens_and_masks[4].cpu().numpy() + t5_attn_mask = tokens_and_masks[5].cpu().numpy() for i, info in enumerate(infos): lg_out_i = lg_out[i] - t5_out_i = t5_out[i] if t5_out is not None else None + t5_out_i = t5_out[i] lg_pooled_i = lg_pooled[i] + l_attn_mask_i = l_attn_mask[i] + g_attn_mask_i = g_attn_mask[i] + t5_attn_mask_i = t5_attn_mask[i] + apply_lg_attn_mask = self.apply_lg_attn_mask + apply_t5_attn_mask = self.apply_t5_attn_mask if self.cache_to_disk: - clip_l_attn_mask, clip_g_attn_mask, t5_attn_mask = tokens_and_masks[3:6] - clip_l_attn_mask_i = clip_l_attn_mask[i].cpu().numpy() - clip_g_attn_mask_i = clip_g_attn_mask[i].cpu().numpy() - t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() if t5_attn_mask is not None else None # shouldn't be None - kwargs = {} - if t5_out is not None: - kwargs["t5_out"] = t5_out_i np.savez( info.text_encoder_outputs_npz, lg_out=lg_out_i, lg_pooled=lg_pooled_i, - clip_l_attn_mask=clip_l_attn_mask_i, - clip_g_attn_mask=clip_g_attn_mask_i, + t5_out=t5_out_i, + clip_l_attn_mask=l_attn_mask_i, + clip_g_attn_mask=g_attn_mask_i, t5_attn_mask=t5_attn_mask_i, - **kwargs, + apply_lg_attn_mask=apply_lg_attn_mask, + apply_t5_attn_mask=apply_t5_attn_mask, ) else: - info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i) + info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i) class Sd3LatentsCachingStrategy(LatentsCachingStrategy): @@ -246,41 +282,3 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy): if not train_util.HIGH_VRAM: train_util.clean_memory_on_device(vae.device) - - -if __name__ == "__main__": - # test code for Sd3TokenizeStrategy - # tokenizer = sd3_models.SD3Tokenizer() - strategy = Sd3TokenizeStrategy(256) - text = "hello world" - - l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) - # print(l_tokens.shape) - print(l_tokens) - print(g_tokens) - print(t5_tokens) - - texts = ["hello world", "the quick brown fox jumps over the lazy dog"] - l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") - g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") - t5_tokens_2 = strategy.t5xxl( - texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt" - ) - print(l_tokens_2) - print(g_tokens_2) - print(t5_tokens_2) - - # compare - print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0])) - print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0])) - print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0])) - - text = ",".join(["hello world! this is long text"] * 50) - l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) - print(l_tokens) - print(g_tokens) - print(t5_tokens) - - print(f"model max length l: {strategy.clip_l.model_max_length}") - print(f"model max length g: {strategy.clip_g.model_max_length}") - print(f"model max length t5: {strategy.t5xxl.model_max_length}") diff --git a/library/train_util.py b/library/train_util.py index 462c7a9a..9ea1eec0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5967,6 +5967,37 @@ def line_to_prompt_dict(line: str) -> dict: return prompt_dict +def load_prompts(prompt_file: str) -> List[Dict]: + # read prompts + if prompt_file.endswith(".txt"): + with open(prompt_file, "r", encoding="utf-8") as f: + lines = f.readlines() + prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] + elif prompt_file.endswith(".toml"): + with open(prompt_file, "r", encoding="utf-8") as f: + data = toml.load(f) + prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] + elif prompt_file.endswith(".json"): + with open(prompt_file, "r", encoding="utf-8") as f: + prompts = json.load(f) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from library.train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + return prompts + + def sample_images_common( pipe_class, accelerator: Accelerator, diff --git a/library/utils.py b/library/utils.py index 8a0c782c..ca0f904d 100644 --- a/library/utils.py +++ b/library/utils.py @@ -13,12 +13,16 @@ from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncest import cv2 from PIL import Image import numpy as np +from safetensors.torch import load_file def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() +# region Logging + + def add_logging_arguments(parser): parser.add_argument( "--console_log_level", @@ -85,6 +89,11 @@ def setup_logging(args=None, log_level=None, reset=False): logger.info(msg_init) +# endregion + +# region PyTorch utils + + def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: """ Convert a string to a torch.dtype @@ -304,6 +313,35 @@ class MemoryEfficientSafeOpen: # return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape) raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)") +def load_safetensors( + path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32 +) -> dict[str, torch.Tensor]: + if disable_mmap: + # return safetensors.torch.load(open(path, "rb").read()) + # use experimental loader + # logger.info(f"Loading without mmap (experimental)") + state_dict = {} + with MemoryEfficientSafeOpen(path) as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key).to(device, dtype=dtype) + return state_dict + else: + try: + state_dict = load_file(path, device=device) + except: + state_dict = load_file(path) # prevent device invalid Error + if dtype is not None: + for key in state_dict.keys(): + state_dict[key] = state_dict[key].to(dtype=dtype) + return state_dict + + + +# endregion + +# region Image utils + + def pil_resize(image, size, interpolation=Image.LANCZOS): has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False @@ -323,9 +361,9 @@ def pil_resize(image, size, interpolation=Image.LANCZOS): return resized_cv2 +# endregion + # TODO make inf_utils.py - - # region Gradual Latent hires fix diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index 630da7e0..d099fe18 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -12,6 +12,7 @@ import torch from safetensors.torch import safe_open, load_file from tqdm import tqdm from PIL import Image +from transformers import CLIPTextModelWithProjection, T5EncoderModel from library.device_utils import init_ipex, get_preferred_device @@ -25,11 +26,14 @@ import logging logger = logging.getLogger(__name__) from library import sd3_models, sd3_utils, strategy_sd3 +from library.utils import load_safetensors -def get_noise(seed, latent): - generator = torch.manual_seed(seed) - return torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu").to(latent.dtype) +def get_noise(seed, latent, device="cpu"): + # generator = torch.manual_seed(seed) + generator = torch.Generator(device) + generator.manual_seed(seed) + return torch.randn(latent.size(), dtype=latent.dtype, layout=latent.layout, generator=generator, device=device) def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): @@ -59,7 +63,7 @@ def do_sample( neg_cond: Tuple[torch.Tensor, torch.Tensor], mmdit: sd3_models.MMDiT, steps: int, - guidance_scale: float, + cfg_scale: float, dtype: torch.dtype, device: str, ): @@ -71,7 +75,7 @@ def do_sample( latent = latent.to(dtype).to(device) - noise = get_noise(seed, latent).to(device) + noise = get_noise(seed, latent, device) model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 @@ -105,7 +109,7 @@ def do_sample( batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) pos_out, neg_out = batched.chunk(2) - denoised = neg_out + (pos_out - neg_out) * guidance_scale + denoised = neg_out + (pos_out - neg_out) * cfg_scale # print(denoised.shape) # d = to_d(x, sigma_hat, denoised) @@ -122,230 +126,68 @@ def do_sample( x = x.to(dtype) latent = x - scale_factor = 1.5305 - shift_factor = 0.0609 - # def process_out(self, latent): - # return (latent / self.scale_factor) + self.shift_factor - latent = (latent / scale_factor) + shift_factor + latent = vae.process_out(latent) return latent -if __name__ == "__main__": - target_height = 1024 - target_width = 1024 - - # steps = 50 # 28 # 50 - guidance_scale = 5 - # seed = 1 # None # 1 - - device = get_preferred_device() - - parser = argparse.ArgumentParser() - parser.add_argument("--ckpt_path", type=str, required=True) - parser.add_argument("--clip_g", type=str, required=False) - parser.add_argument("--clip_l", type=str, required=False) - parser.add_argument("--t5xxl", type=str, required=False) - parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77") - parser.add_argument("--apply_lg_attn_mask", action="store_true") - parser.add_argument("--apply_t5_attn_mask", action="store_true") - parser.add_argument("--prompt", type=str, default="A photo of a cat") - # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders - parser.add_argument("--negative_prompt", type=str, default="") - parser.add_argument("--output_dir", type=str, default=".") - parser.add_argument("--do_not_use_t5xxl", action="store_true") - parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch") - parser.add_argument("--fp16", action="store_true") - parser.add_argument("--bf16", action="store_true") - parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--steps", type=int, default=50) - # parser.add_argument( - # "--lora_weights", - # type=str, - # nargs="*", - # default=[], - # help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)", - # ) - # parser.add_argument("--interactive", action="store_true") - args = parser.parse_args() - - seed = args.seed - steps = args.steps - - sd3_dtype = torch.float32 - if args.fp16: - sd3_dtype = torch.float16 - elif args.bf16: - sd3_dtype = torch.bfloat16 - - # TODO test with separated safetenors files for each model - - # load state dict - logger.info(f"Loading SD3 models from {args.ckpt_path}...") - state_dict = load_file(args.ckpt_path) - - if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: - # found clip_g: remove prefix "text_encoders.clip_g." - logger.info("clip_g is included in the checkpoint") - clip_g_sd = {} - prefix = "text_encoders.clip_g." - for k, v in list(state_dict.items()): - if k.startswith(prefix): - clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) - else: - logger.info(f"Lodaing clip_g from {args.clip_g}...") - clip_g_sd = load_file(args.clip_g) - for key in list(clip_g_sd.keys()): - clip_g_sd["transformer." + key] = clip_g_sd.pop(key) - - if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: - # found clip_l: remove prefix "text_encoders.clip_l." - logger.info("clip_l is included in the checkpoint") - clip_l_sd = {} - prefix = "text_encoders.clip_l." - for k, v in list(state_dict.items()): - if k.startswith(prefix): - clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) - else: - logger.info(f"Lodaing clip_l from {args.clip_l}...") - clip_l_sd = load_file(args.clip_l) - for key in list(clip_l_sd.keys()): - clip_l_sd["transformer." + key] = clip_l_sd.pop(key) - - if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: - # found t5xxl: remove prefix "text_encoders.t5xxl." - logger.info("t5xxl is included in the checkpoint") - if not args.do_not_use_t5xxl: - t5xxl_sd = {} - prefix = "text_encoders.t5xxl." - for k, v in list(state_dict.items()): - if k.startswith(prefix): - t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) - else: - logger.info("but not used") - for key in list(state_dict.keys()): - if key.startswith("text_encoders.t5xxl."): - state_dict.pop(key) - t5xxl_sd = None - elif args.t5xxl: - assert not args.do_not_use_t5xxl, "t5xxl is not used but specified" - logger.info(f"Lodaing t5xxl from {args.t5xxl}...") - t5xxl_sd = load_file(args.t5xxl) - for key in list(t5xxl_sd.keys()): - t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key) - else: - logger.info("t5xxl is not used") - t5xxl_sd = None - - use_t5xxl = t5xxl_sd is not None - - # MMDiT and VAE - vae_sd = {} - vae_prefix = "first_stage_model." - mmdit_prefix = "model.diffusion_model." - for k, v in list(state_dict.items()): - if k.startswith(vae_prefix): - vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) - elif k.startswith(mmdit_prefix): - state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k) - - # load tokenizers - logger.info("Loading tokenizers...") - tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length) - - # load models - # logger.info("Create MMDiT from SD3 checkpoint...") - # mmdit = sd3_utils.create_mmdit_from_sd3_checkpoint(state_dict) - logger.info("Create MMDiT") - mmdit = sd3_models.create_mmdit_sd3_medium_configs(args.attn_mode) - - logger.info("Loading state dict...") - info = mmdit.load_state_dict(state_dict) - logger.info(f"Loaded MMDiT: {info}") - - logger.info(f"Move MMDiT to {device} and {sd3_dtype}...") - mmdit.to(device, dtype=sd3_dtype) - mmdit.eval() - - # load VAE - logger.info("Create VAE") - vae = sd3_models.SDVAE() - logger.info("Loading state dict...") - info = vae.load_state_dict(vae_sd) - logger.info(f"Loaded VAE: {info}") - - logger.info(f"Move VAE to {device} and {sd3_dtype}...") - vae.to(device, dtype=sd3_dtype) - vae.eval() - - # load text encoders - logger.info("Create clip_l") - clip_l = sd3_models.create_clip_l(device, sd3_dtype, clip_l_sd) - - logger.info("Loading state dict...") - info = clip_l.load_state_dict(clip_l_sd) - logger.info(f"Loaded clip_l: {info}") - - logger.info(f"Move clip_l to {device} and {sd3_dtype}...") - clip_l.to(device, dtype=sd3_dtype) - clip_l.eval() - logger.info(f"Set attn_mode to {args.attn_mode}...") - clip_l.set_attn_mode(args.attn_mode) - - logger.info("Create clip_g") - clip_g = sd3_models.create_clip_g(device, sd3_dtype, clip_g_sd) - - logger.info("Loading state dict...") - info = clip_g.load_state_dict(clip_g_sd) - logger.info(f"Loaded clip_g: {info}") - - logger.info(f"Move clip_g to {device} and {sd3_dtype}...") - clip_g.to(device, dtype=sd3_dtype) - clip_g.eval() - logger.info(f"Set attn_mode to {args.attn_mode}...") - clip_g.set_attn_mode(args.attn_mode) - - if use_t5xxl: - logger.info("Create t5xxl") - t5xxl = sd3_models.create_t5xxl(device, sd3_dtype, t5xxl_sd) - - logger.info("Loading state dict...") - info = t5xxl.load_state_dict(t5xxl_sd) - logger.info(f"Loaded t5xxl: {info}") - - logger.info(f"Move t5xxl to {device} and {sd3_dtype}...") - t5xxl.to(device, dtype=sd3_dtype) - # t5xxl.to("cpu", dtype=torch.float32) # run on CPU - t5xxl.eval() - logger.info(f"Set attn_mode to {args.attn_mode}...") - t5xxl.set_attn_mode(args.attn_mode) - else: - t5xxl = None - +def generate_image( + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, + clip_l: CLIPTextModelWithProjection, + clip_g: CLIPTextModelWithProjection, + t5xxl: T5EncoderModel, + steps: int, + prompt: str, + seed: int, + target_width: int, + target_height: int, + device: str, + negative_prompt: str, + cfg_scale: float, +): # prepare embeddings logger.info("Encoding prompts...") - encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() - tokens_and_masks = tokenize_strategy.tokenize(args.prompt) - lg_out, t5_out, pooled = encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask - ) - cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + # TODO support one-by-one offloading + clip_l.to(device) + clip_g.to(device) + t5xxl.to(device) - tokens_and_masks = tokenize_strategy.tokenize(args.negative_prompt) - lg_out, t5_out, pooled = encoding_strategy.encode_tokens( - tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask - ) - neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + with torch.no_grad(): + tokens_and_masks = tokenize_strategy.tokenize(prompt) + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask + ) + cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + tokens_and_masks = tokenize_strategy.tokenize(negative_prompt) + lg_out, t5_out, pooled, neg_l_attn_mask, neg_g_attn_mask, neg_t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask + ) + neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # attn masks are not used currently + + if args.offload: + clip_l.to("cpu") + clip_g.to("cpu") + t5xxl.to("cpu") # generate image logger.info("Generating image...") - latent_sampled = do_sample( - target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, guidance_scale, sd3_dtype, device - ) + mmdit.to(device) + latent_sampled = do_sample(target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, cfg_scale, sd3_dtype, device) + if args.offload: + mmdit.to("cpu") # latent to image + vae.to(device) with torch.no_grad(): image = vae.decode(latent_sampled) + + if args.offload: + vae.to("cpu") + image = image.float() image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) @@ -359,3 +201,179 @@ if __name__ == "__main__": out_image.save(output_path) logger.info(f"Saved image to {output_path}") + + +if __name__ == "__main__": + target_height = 1024 + target_width = 1024 + + # steps = 50 # 28 # 50 + # cfg_scale = 5 + # seed = 1 # None # 1 + + device = get_preferred_device() + + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--clip_g", type=str, required=False) + parser.add_argument("--clip_l", type=str, required=False) + parser.add_argument("--t5xxl", type=str, required=False) + parser.add_argument("--t5xxl_token_length", type=int, default=256, help="t5xxl token length, default: 256") + parser.add_argument("--apply_lg_attn_mask", action="store_true") + parser.add_argument("--apply_t5_attn_mask", action="store_true") + parser.add_argument("--prompt", type=str, default="A photo of a cat") + # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders + parser.add_argument("--negative_prompt", type=str, default="") + parser.add_argument("--cfg_scale", type=float, default=5.0) + parser.add_argument("--offload", action="store_true", help="Offload to CPU") + parser.add_argument("--output_dir", type=str, default=".") + # parser.add_argument("--do_not_use_t5xxl", action="store_true") + # parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch") + parser.add_argument("--fp16", action="store_true") + parser.add_argument("--bf16", action="store_true") + parser.add_argument("--seed", type=int, default=1) + parser.add_argument("--steps", type=int, default=50) + # parser.add_argument( + # "--lora_weights", + # type=str, + # nargs="*", + # default=[], + # help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)", + # ) + parser.add_argument("--width", type=int, default=target_width) + parser.add_argument("--height", type=int, default=target_height) + parser.add_argument("--interactive", action="store_true") + args = parser.parse_args() + + seed = args.seed + steps = args.steps + + sd3_dtype = torch.float32 + if args.fp16: + sd3_dtype = torch.float16 + elif args.bf16: + sd3_dtype = torch.bfloat16 + + loading_device = "cpu" if args.offload else device + + # load state dict + logger.info(f"Loading SD3 models from {args.ckpt_path}...") + # state_dict = load_file(args.ckpt_path) + state_dict = load_safetensors(args.ckpt_path, loading_device, disable_mmap=True, dtype=sd3_dtype) + + # load text encoders + clip_l = sd3_utils.load_clip_l(args.clip_l, sd3_dtype, loading_device, state_dict=state_dict) + clip_g = sd3_utils.load_clip_g(args.clip_g, sd3_dtype, loading_device, state_dict=state_dict) + t5xxl = sd3_utils.load_t5xxl(args.t5xxl, sd3_dtype, loading_device, state_dict=state_dict) + + # MMDiT and VAE + vae = sd3_utils.load_vae(None, sd3_dtype, loading_device, state_dict=state_dict) + mmdit = sd3_utils.load_mmdit(state_dict, sd3_dtype, loading_device) + + clip_l.to(sd3_dtype) + clip_g.to(sd3_dtype) + t5xxl.to(sd3_dtype) + vae.to(sd3_dtype) + mmdit.to(sd3_dtype) + if not args.offload: + # make sure to move to the device: some tensors are created in the constructor on the CPU + clip_l.to(device) + clip_g.to(device) + t5xxl.to(device) + vae.to(device) + mmdit.to(device) + + clip_l.eval() + clip_g.eval() + t5xxl.eval() + mmdit.eval() + vae.eval() + + # load tokenizers + logger.info("Loading tokenizers...") + tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length) + encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() + + if not args.interactive: + generate_image( + mmdit, + vae, + clip_l, + clip_g, + t5xxl, + args.steps, + args.prompt, + args.seed, + args.width, + args.height, + device, + args.negative_prompt, + args.cfg_scale, + ) + else: + # loop for interactive + width = args.width + height = args.height + steps = None + cfg_scale = args.cfg_scale + + while True: + print( + "Enter prompt (empty to exit). Options: --w --h --s --d " + " --n , `--n -` for empty negative prompt" + "Options are kept for the next prompt. Current options:" + f" width={width}, height={height}, steps={steps}, seed={seed}, cfg_scale={cfg_scale}" + ) + prompt = input() + if prompt == "": + break + + # parse options + options = prompt.split("--") + prompt = options[0].strip() + seed = None + negative_prompt = None + for opt in options[1:]: + try: + opt = opt.strip() + if opt.startswith("w"): + width = int(opt[1:].strip()) + elif opt.startswith("h"): + height = int(opt[1:].strip()) + elif opt.startswith("s"): + steps = int(opt[1:].strip()) + elif opt.startswith("d"): + seed = int(opt[1:].strip()) + # elif opt.startswith("m"): + # mutipliers = opt[1:].strip().split(",") + # if len(mutipliers) != len(lora_models): + # logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + # continue + # for i, lora_model in enumerate(lora_models): + # lora_model.set_multiplier(float(mutipliers[i])) + elif opt.startswith("n"): + negative_prompt = opt[1:].strip() + if negative_prompt == "-": + negative_prompt = "" + elif opt.startswith("c"): + cfg_scale = float(opt[1:].strip()) + except ValueError as e: + logger.error(f"Invalid option: {opt}, {e}") + + generate_image( + mmdit, + vae, + clip_l, + clip_g, + t5xxl, + steps if steps is not None else args.steps, + prompt, + seed if seed is not None else args.seed, + width, + height, + device, + negative_prompt if negative_prompt is not None else args.negative_prompt, + cfg_scale, + ) + + logger.info("Done!") diff --git a/sd3_train.py b/sd3_train.py index ef18c32c..6336b4cf 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -1,6 +1,7 @@ # training with captions import argparse +from concurrent.futures import ThreadPoolExecutor import copy import math import os @@ -11,6 +12,7 @@ import toml from tqdm import tqdm import torch +from library import utils from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -38,7 +40,7 @@ from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) -import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments # from library.custom_train_functions import ( # apply_snr_weight, @@ -61,23 +63,13 @@ def train(args): if not args.skip_cache_check: args.skip_cache_check = args.skip_latents_validity_check - assert ( - not args.weighted_captions - ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" # assert ( # not args.train_text_encoder or not args.cache_text_encoder_outputs # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" - # # training text encoder is not supported - # assert ( - # not args.train_text_encoder - # ), "training text encoder is not supported currently / text encoderの学習は現在サポートされていません" - - # # training without text encoder cache is not supported: because T5XXL must be cached - # assert ( - # args.cache_text_encoder_outputs - # ), "training without text encoder cache is not supported currently / text encoderのキャッシュなしの学習は現在サポートされていません" - assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), ( "when training text encoder, text encoder outputs must not be cached (except for T5XXL)" + " / text encoderの学習時はtext encoderの出力はキャッシュできません(t5xxlのみキャッシュすることは可能です)" @@ -90,13 +82,13 @@ def train(args): ) args.cache_text_encoder_outputs = True - # if args.block_lr: - # block_lrs = [float(lr) for lr in args.block_lr.split(",")] - # assert ( - # len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR - # ), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください" - # else: - # block_lrs = None + if args.train_t5xxl: + assert ( + args.train_text_encoder + ), "when training T5XXL, text encoder (CLIP-L/G) must be trained / T5XXLを学習するときはtext encoder (CLIP-L/G)も学習する必要があります" + assert ( + not args.cache_text_encoder_outputs + ), "when training T5XXL, t5xxl output must not be cached / T5XXLを学習するときはt5xxlの出力をキャッシュできません" cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None @@ -111,11 +103,6 @@ def train(args): ) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) - # load tokenizer and prepare tokenize strategy - sd3_tokenizer = sd3_models.SD3Tokenizer(t5xxl_max_length=args.t5xxl_max_token_length) - sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length) - strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy) - # データセットを準備する if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) @@ -156,10 +143,10 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[sd3_tokenizer]) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, [sd3_tokenizer]) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -205,41 +192,152 @@ def train(args): # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) - vae_dtype = weight_dtype # torch.float32 if args.no_half_vae else weight_dtype # SD3 VAE works with fp16 - - t5xxl_dtype = weight_dtype - if args.t5xxl_dtype is not None: - if args.t5xxl_dtype == "fp16": - t5xxl_dtype = torch.float16 - elif args.t5xxl_dtype == "bf16": - t5xxl_dtype = torch.bfloat16 - elif args.t5xxl_dtype == "fp32" or args.t5xxl_dtype == "float": - t5xxl_dtype = torch.float32 - else: - raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") - t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device - - clip_dtype = weight_dtype # if not args.train_text_encoder else None # モデルを読み込む - attn_mode = "xformers" if args.xformers else "torch" - assert ( - attn_mode == "torch" - ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + # t5xxl_dtype = weight_dtype + # if args.t5xxl_dtype is not None: + # if args.t5xxl_dtype == "fp16": + # t5xxl_dtype = torch.float16 + # elif args.t5xxl_dtype == "bf16": + # t5xxl_dtype = torch.bfloat16 + # elif args.t5xxl_dtype == "fp32" or args.t5xxl_dtype == "float": + # t5xxl_dtype = torch.float32 + # else: + # raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") + # t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device + # clip_dtype = weight_dtype # if not args.train_text_encoder else None - # SD3 state dict may contain multiple models, so we need to load it and extract one by one. annoying. - logger.info(f"Loading SD3 models from {args.pretrained_model_name_or_path}") - device_to_load = accelerator.device if args.lowram else "cpu" - sd3_state_dict = sd3_utils.load_safetensors( - args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors - ) + # if clip_l is not specified, the checkpoint must contain clip_l, so we load state dict here + # if full_fp16/bf16, model_dtype is casted to fp16/bf16. If not, model_dtype is None (float32). + # by loading with model_dtype, we can reduce memory usage. + model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx) + if args.clip_l is None: + sd3_state_dict = utils.load_safetensors( + args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype + ) + else: + sd3_state_dict = None + + # load tokenizer and prepare tokenize strategy + if args.t5xxl_max_token_length is None: + t5xxl_max_token_length = 256 # default value for T5XXL + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy) + + # load clip_l, clip_g, t5xxl for caching text encoder outputs + # clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) + # clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) + clip_l = sd3_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) + clip_g = sd3_utils.load_clip_g(args.clip_g, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) + t5xxl = sd3_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) + assert clip_l is not None and clip_g is not None and t5xxl is not None, "clip_l, clip_g, t5xxl must be specified" + + # prepare text encoding strategy + text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy(args.apply_lg_attn_mask, args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # 学習を準備する:モデルを適切な状態にする + train_clip = False + train_t5xxl = False + + if args.train_text_encoder: + accelerator.print("enable text encoder training") + if args.gradient_checkpointing: + clip_l.gradient_checkpointing_enable() + clip_g.gradient_checkpointing_enable() + if args.train_t5xxl: + t5xxl.gradient_checkpointing_enable() + + lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train + lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train + lr_t5xxl = args.learning_rate_te3 if args.learning_rate_te3 is not None else args.learning_rate # 0 means not train + train_clip = lr_te1 != 0 or lr_te2 != 0 + train_t5xxl = lr_t5xxl != 0 and args.train_t5xxl + + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + t5xxl.to(weight_dtype) + clip_l.requires_grad_(train_clip) + clip_g.requires_grad_(train_clip) + t5xxl.requires_grad_(train_t5xxl) + else: + print("disable text encoder training") + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + t5xxl.to(weight_dtype) + clip_l.requires_grad_(False) + clip_g.requires_grad_(False) + t5xxl.requires_grad_(False) + lr_te1 = 0 + lr_te2 = 0 + lr_t5xxl = 0 + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + clip_l.to(accelerator.device) + clip_g.to(accelerator.device) + t5xxl.to(accelerator.device) + clip_l.eval() + clip_g.eval() + t5xxl.eval() + + text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + train_clip or args.use_t5xxl_cache_only, # if clip is trained or t5xxl is cached, caching is partial + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = sd3_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, + [clip_l, clip_g, t5xxl], + tokens_and_masks, + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + if args.use_t5xxl_cache_only: + clip_l = None + clip_g = None + t5xxl = None + + clean_memory_on_device(accelerator.device) # load VAE for caching latents - vae: sd3_models.SDVAE = None + if sd3_state_dict is None: + sd3_state_dict = utils.load_safetensors( + args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype + ) + + vae = sd3_utils.load_vae(args.vae, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) if cache_latents: - vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) - vae.to(accelerator.device, dtype=vae_dtype) + # vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) + vae.to(accelerator.device, dtype=weight_dtype) vae.requires_grad_(False) vae.eval() @@ -250,127 +348,36 @@ def train(args): accelerator.wait_for_everyone() - # load clip_l, clip_g, t5xxl for caching text encoder outputs - # # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0. - # mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( - # args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype - # ) - clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) - clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) - assert clip_l is not None, "clip_l is required / clip_lは必須です" - assert clip_g is not None, "clip_g is required / clip_gは必須です" - - t5xxl = sd3_train_utils.load_target_model("t5xxl", args, sd3_state_dict, accelerator, attn_mode, t5xxl_dtype, device_to_load) - # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) - - # should be deleted after caching text encoder outputs when not training text encoder - # this strategy should not be used other than this process - text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() - strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) - - # 学習を準備する:モデルを適切な状態にする - train_clip_l = False - train_clip_g = False - train_t5xxl = False - - if args.train_text_encoder: - accelerator.print("enable text encoder training") - if args.gradient_checkpointing: - clip_l.gradient_checkpointing_enable() - clip_g.gradient_checkpointing_enable() - lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train - lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train - train_clip_l = lr_te1 != 0 - train_clip_g = lr_te2 != 0 - - if not train_clip_l: - clip_l.to(weight_dtype) - if not train_clip_g: - clip_g.to(weight_dtype) - clip_l.requires_grad_(train_clip_l) - clip_g.requires_grad_(train_clip_g) - clip_l.train(train_clip_l) - clip_g.train(train_clip_g) - else: - clip_l.to(weight_dtype) - clip_g.to(weight_dtype) - clip_l.requires_grad_(False) - clip_g.requires_grad_(False) - clip_l.eval() - clip_g.eval() - - if t5xxl is not None: - t5xxl.to(t5xxl_dtype) - t5xxl.requires_grad_(False) - t5xxl.eval() - - # cache text encoder outputs - sample_prompts_te_outputs = None - if args.cache_text_encoder_outputs: - # Text Encodes are eval and no grad here - clip_l.to(accelerator.device) - clip_g.to(accelerator.device) - if t5xxl is not None: - t5xxl.to(t5xxl_device) - - text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, - args.text_encoder_batch_size, - args.skip_cache_check, - train_clip_g or train_clip_l or args.use_t5xxl_cache_only, - args.apply_lg_attn_mask, - args.apply_t5_attn_mask, - ) - strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) - - clip_l.to(accelerator.device, dtype=weight_dtype) - clip_g.to(accelerator.device, dtype=weight_dtype) - if t5xxl is not None: - t5xxl.to(t5xxl_device, dtype=t5xxl_dtype) - - with accelerator.autocast(): - train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator) - - # cache sample prompt's embeddings to free text encoder's memory - if args.sample_prompts is not None: - logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") - prompts = sd3_train_utils.load_prompts(args.sample_prompts) - sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs - with accelerator.autocast(), torch.no_grad(): - for prompt_dict in prompts: - for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: - if p not in sample_prompts_te_outputs: - logger.info(f"cache Text Encoder outputs for prompt: {p}") - tokens_list = sd3_tokenize_strategy.tokenize(p) - sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( - sd3_tokenize_strategy, - [clip_l, clip_g, t5xxl], - tokens_list, - args.apply_lg_attn_mask, - args.apply_t5_attn_mask, - ) - - accelerator.wait_for_everyone() - # load MMDIT - # if full_fp16/bf16, model_dtype is casted to fp16/bf16. If not, model_dtype is None (float32). - # by loading with model_dtype, we can reduce memory usage. - model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx) - mmdit = sd3_train_utils.load_target_model("mmdit", args, sd3_state_dict, accelerator, attn_mode, model_dtype, device_to_load) + mmdit = sd3_utils.load_mmdit( + sd3_state_dict, + model_dtype, + "cpu", + ) + + # attn_mode = "xformers" if args.xformers else "torch" + # assert ( + # attn_mode == "torch" + # ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + + # SD3 state dict may contain multiple models, so we need to load it and extract one by one. annoying. + logger.info(f"Loading SD3 models from {args.pretrained_model_name_or_path}") + device_to_load = accelerator.device if args.lowram else "cpu" + sd3_state_dict = utils.load_safetensors(args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors) + if args.gradient_checkpointing: mmdit.enable_gradient_checkpointing() train_mmdit = args.learning_rate != 0 mmdit.requires_grad_(train_mmdit) if not train_mmdit: - mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdie will not be prepared + mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdit will not be prepared if not cache_latents: - # load VAE here if not cached - vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) + # move to accelerator device vae.requires_grad_(False) vae.eval() - vae.to(accelerator.device, dtype=vae_dtype) + vae.to(accelerator.device, dtype=weight_dtype) mmdit.requires_grad_(train_mmdit) if not train_mmdit: @@ -394,19 +401,24 @@ def train(args): training_models = [] params_to_optimize = [] - # if train_unet: + param_names = [] training_models.append(mmdit) - # if block_lrs is None: params_to_optimize.append({"params": list(filter(lambda p: p.requires_grad, mmdit.parameters())), "lr": args.learning_rate}) - # else: - # params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs)) + param_names.append([n for n, _ in mmdit.named_parameters()]) - # if train_clip_l: - # training_models.append(clip_l) - # params_to_optimize.append({"params": list(clip_l.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) - # if train_clip_g: - # training_models.append(clip_g) - # params_to_optimize.append({"params": list(clip_g.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) + if train_clip: + if lr_te1 > 0: + training_models.append(clip_l) + params_to_optimize.append({"params": list(clip_l.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) + param_names.append([n for n, _ in clip_l.named_parameters()]) + if lr_te2 > 0: + training_models.append(clip_g) + params_to_optimize.append({"params": list(clip_g.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) + param_names.append([n for n, _ in clip_g.named_parameters()]) + if train_t5xxl: + training_models.append(t5xxl) + params_to_optimize.append({"params": list(t5xxl.parameters()), "lr": args.learning_rate_te3 or args.learning_rate}) + param_names.append([n for n, _ in t5xxl.named_parameters()]) # calculate number of trainable parameters n_params = 0 @@ -414,47 +426,49 @@ def train(args): for p in group["params"]: n_params += p.numel() - accelerator.print(f"train mmdit: {train_mmdit}") # , clip_l: {train_clip_l}, clip_g: {train_clip_g}") + accelerator.print(f"train mmdit: {train_mmdit} , clip:{train_clip}, t5xxl:{train_t5xxl}") accelerator.print(f"number of models: {len(training_models)}") accelerator.print(f"number of trainable parameters: {n_params}") # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html - # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. # This balances memory usage and management complexity. - # calculate total number of parameters - n_total_params = sum(len(params["params"]) for params in params_to_optimize) - params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) - - # split params into groups, keeping the learning rate the same for all params in a group - # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) + # split params into groups for mmdit. clip_l, clip_g, t5xxl are in each group grouped_params = [] - param_group = [] - param_group_lr = -1 - for group in params_to_optimize: - lr = group["lr"] - for p in group["params"]: - # if the learning rate is different for different params, start a new group - if lr != param_group_lr: - if param_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) - param_group = [] - param_group_lr = lr + param_group = {} + group = params_to_optimize[0] + named_parameters = list(mmdit.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # joint or other + if np[0].startswith("joint_blocks"): + block_idx = int(np[0].split(".")[1]) + block_type = "joint" + else: + block_idx = -1 - param_group.append(p) + param_group_key = (block_type, block_idx) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) - # if the group has enough parameters, start a new group - if len(param_group) == params_per_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) - param_group = [] - param_group_lr = -1 + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) - if param_group: - grouped_params.append({"params": param_group, "lr": param_group_lr}) + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + grouped_params.extend(params_to_optimize[1:]) # add clip_l, clip_g, t5xxl if they are trained # prepare optimizers for each group optimizers = [] @@ -463,10 +477,15 @@ def train(args): optimizers.append(optimizer) optimizer = optimizers[0] # avoid error in the following code - logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) # prepare dataloader # strategies are set here because they cannot be referenced in another process. Copy them with the dataset @@ -497,7 +516,7 @@ def train(args): train_dataset_group.set_max_train_steps(args.max_train_steps) # lr schedulerを用意する - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: # prepare lr schedulers for each optimizer lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] lr_scheduler = lr_schedulers[0] # avoid error in the following code @@ -511,18 +530,22 @@ def train(args): ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" accelerator.print("enable full fp16 training.") mmdit.to(weight_dtype) - clip_l.to(weight_dtype) - clip_g.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + if clip_g is not None: + clip_g.to(weight_dtype) if t5xxl is not None: - t5xxl.to(weight_dtype) # TODO check works with fp16 or not + t5xxl.to(weight_dtype) elif args.full_bf16: assert ( args.mixed_precision == "bf16" ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" accelerator.print("enable full bf16 training.") mmdit.to(weight_dtype) - clip_l.to(weight_dtype) - clip_g.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + if clip_g is not None: + clip_g.to(weight_dtype) if t5xxl is not None: t5xxl.to(weight_dtype) @@ -533,14 +556,7 @@ def train(args): # clip_l.text_model.final_layer_norm.requires_grad_(False) # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する - if args.cache_text_encoder_outputs: - # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 - clip_l.to("cpu", dtype=torch.float32) - clip_g.to("cpu", dtype=torch.float32) - if t5xxl is not None: - t5xxl.to("cpu", dtype=torch.float32) - clean_memory_on_device(accelerator.device) - else: + if not args.cache_text_encoder_outputs: # make sure Text Encoders are on GPU # TODO support CPU for text encoders clip_l.to(accelerator.device) @@ -548,18 +564,11 @@ def train(args): if t5xxl is not None: t5xxl.to(accelerator.device) - # TODO cache sample prompt's embeddings to free text encoder's memory - if args.cache_text_encoder_outputs: - if not args.save_t5xxl: - t5xxl = None # free memory clean_memory_on_device(accelerator.device) if args.deepspeed: ds_model = deepspeed_utils.prepare_deepspeed_model( - args, - mmdit=mmdit, - clip_l=clip_l if train_clip_l else None, - clip_g=clip_g if train_clip_g else None, + args, mmdit=mmdit, clip_l=clip_l if train_clip else None, clip_g=clip_g if train_clip else None ) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -571,10 +580,11 @@ def train(args): # acceleratorがなんかよろしくやってくれるらしい if train_mmdit: mmdit = accelerator.prepare(mmdit) - if train_clip_l: + if train_clip: clip_l = accelerator.prepare(clip_l) - if train_clip_g: clip_g = accelerator.prepare(clip_g) + if train_t5xxl: + t5xxl = accelerator.prepare(t5xxl) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする @@ -586,24 +596,110 @@ def train(args): # resumeする train_util.resume_from_local_or_hf_if_specified(accelerator, args) + # memory efficient block swapping + + def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, device): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda, dvc): + # print(f"Backward: Move block {bidx_to_cpu} to CPU") + block_to_cpu = block_to_cpu.to("cpu", non_blocking=True) + torch.cuda.empty_cache() + + block_to_cuda = block_to_cuda.to(dvc, non_blocking=True) + torch.cuda.synchronize() + return bidx_to_cpu, bidx_to_cuda + + block_to_cpu = blocks[block_idx_to_cpu] + block_to_cuda = blocks[block_idx_to_cuda] + + futures[block_idx_to_cuda] = thread_pool.submit( + move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda, device + ) + + def wait_blocks_move(block_idx, futures): + if block_idx not in futures: + return + future = futures.pop(block_idx) + future.result() + if args.fused_backward_pass: # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) - for param_group in optimizer.param_groups: - for parameter in param_group["params"]: + + blocks_to_swap = args.blocks_to_swap + num_blocks = len(accelerator.unwrap_model(mmdit).joint_blocks) + handled_block_indices = set() + + n = 1 # only asynchronous purpose, no need to increase this number + # n = 2 + # n = max(1, os.cpu_count() // 2) + thread_pool = ThreadPoolExecutor(max_workers=n) + futures = {} + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: + grad_hook = None - def __grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None + if blocks_to_swap: + is_block = param_name.startswith("double_blocks") + if is_block: + block_idx = int(param_name.split(".")[1]) + if block_idx not in handled_block_indices: + # swap following (already backpropagated) block + handled_block_indices.add(block_idx) - parameter.register_post_accumulate_grad_hook(__grad_hook) + # if n blocks were already backpropagated + num_blocks_propagated = num_blocks - block_idx - 1 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap + waiting = block_idx > 0 and block_idx <= blocks_to_swap + if swapping or waiting: + block_idx_to_cpu = num_blocks - num_blocks_propagated + block_idx_to_cuda = blocks_to_swap - num_blocks_propagated + block_idx_to_wait = block_idx - 1 - elif args.fused_optimizer_groups: + # create swap hook + def create_swap_grad_hook( + bidx_to_cpu, bidx_to_cuda, bidx_to_wait, bidx: int, swpng: bool, wtng: bool + ): + def __grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + if swpng: + submit_move_blocks( + futures, + thread_pool, + bidx_to_cpu, + bidx_to_cuda, + mmdit.joint_blocks, + accelerator.device, + ) + if wtng: + wait_blocks_move(bidx_to_wait, futures) + + return __grad_hook + + grad_hook = create_swap_grad_hook( + block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, block_idx, swapping, waiting + ) + + if grad_hook is None: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + grad_hook = __grad_hook + + parameter.register_post_accumulate_grad_hook(grad_hook) + + elif args.blockwise_fused_optimizers: # prepare for additional optimizers and lr schedulers for i in range(1, len(optimizers)): optimizers[i] = accelerator.prepare(optimizers[i]) @@ -618,22 +714,59 @@ def train(args): num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} + blocks_to_swap = args.blocks_to_swap + num_blocks = len(accelerator.unwrap_model(mmdit).joint_blocks) + + n = 1 # only asynchronous purpose, no need to increase this number + # n = max(1, os.cpu_count() // 2) + thread_pool = ThreadPoolExecutor(max_workers=n) + futures = {} + for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: + block_type, block_idx = block_types_and_indices[opt_idx] - def optimizer_hook(parameter: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + def create_optimizer_hook(btype, bidx): + def optimizer_hook(parameter: torch.Tensor): + # print(f"optimizer_hook: {btype}, {bidx}") + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) - i = parameter_optimizer_map[parameter] - optimizer_hooked_count[i] += 1 - if optimizer_hooked_count[i] == num_parameters_per_group[i]: - optimizers[i].step() - optimizers[i].zero_grad(set_to_none=True) + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) - parameter.register_post_accumulate_grad_hook(optimizer_hook) + # swap blocks if necessary + if blocks_to_swap and btype == "joint": + num_blocks_propagated = num_blocks - bidx + + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap + waiting = bidx > 0 and bidx <= blocks_to_swap + + if swapping: + block_idx_to_cpu = num_blocks - num_blocks_propagated + block_idx_to_cuda = blocks_to_swap - num_blocks_propagated + # print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}") + submit_move_blocks( + futures, + thread_pool, + block_idx_to_cpu, + block_idx_to_cuda, + mmdit.joint_blocks, + accelerator.device, + ) + + if waiting: + block_idx_to_wait = bidx - 1 + wait_blocks_move(block_idx_to_wait, futures) + + return optimizer_hook + + parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx)) parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 @@ -661,17 +794,9 @@ def train(args): progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 - # noise_scheduler = DDPMScheduler( - # beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - # ) - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) noise_scheduler_copy = copy.deepcopy(noise_scheduler) - # prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - # if args.zero_terminal_snr: - # custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) - if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: @@ -685,60 +810,13 @@ def train(args): ) # For --sample_at_first + optimizer_eval_fn() sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs) + optimizer_train_fn() if len(accelerator.trackers) > 0: # log empty object to commit the sample images to wandb accelerator.log({}, step=0) - # following function will be moved to sd3_train_utils - - def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): - sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) - schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) - timesteps = timesteps.to(accelerator.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) - return sigma - - def compute_density_for_timestep_sampling( - weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None - ): - """Compute the density for sampling the timesteps when doing SD3 training. - - Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. - - SD3 paper reference: https://arxiv.org/abs/2403.03206v1. - """ - if weighting_scheme == "logit_normal": - # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") - u = torch.nn.functional.sigmoid(u) - elif weighting_scheme == "mode": - u = torch.rand(size=(batch_size,), device="cpu") - u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) - else: - u = torch.rand(size=(batch_size,), device="cpu") - return u - - def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): - """Computes loss weighting scheme for SD3 training. - - Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. - - SD3 paper reference: https://arxiv.org/abs/2403.03206v1. - """ - if weighting_scheme == "sigma_sqrt": - weighting = (sigmas**-2.0).float() - elif weighting_scheme == "cosmap": - bot = 1 - 2 * sigmas + 2 * sigmas**2 - weighting = 2 / (math.pi * bot) - else: - weighting = torch.ones_like(sigmas) - return weighting - loss_recorder = train_util.LossRecorder() epoch = 0 # avoid error when max_train_steps is 0 for epoch in range(num_train_epochs): @@ -751,16 +829,16 @@ def train(args): for step, batch in enumerate(train_dataloader): current_step.value = global_step - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step with accelerator.accumulate(*training_models): if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) else: with torch.no_grad(): # encode images to latents. images are [-1, 1] - latents = vae.encode(batch["images"].to(vae_dtype)).to(weight_dtype) + latents = vae.encode(batch["images"]) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): @@ -772,7 +850,7 @@ def train(args): text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: - lg_out, t5_out, lg_pooled = text_encoder_outputs_list + lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list if args.use_t5xxl_cache_only: lg_out = None lg_pooled = None @@ -781,7 +859,7 @@ def train(args): t5_out = None lg_pooled = None - if lg_out is None or (train_clip_l or train_clip_g): + if lg_out is None: # not cached or training, so get from text encoders input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): @@ -811,21 +889,10 @@ def train(args): noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image - # for weighting schemes where we sample timesteps non-uniformly - u = compute_density_for_timestep_sampling( - weighting_scheme=args.weighting_scheme, - batch_size=bsz, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - mode_scale=args.mode_scale, + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype ) - indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() - timesteps = noise_scheduler_copy.timesteps[indices].to(device=accelerator.device) - - # Add noise according to flow matching. - sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents # debug: NaN check for all inputs if torch.any(torch.isnan(noisy_model_input)): @@ -840,6 +907,7 @@ def train(args): # call model with accelerator.autocast(): + # TODO support attention mask model_pred = mmdit(noisy_model_input, timesteps, context=context, y=lg_pooled) # Follow: Section 5 of https://arxiv.org/abs/2206.00364. @@ -848,21 +916,34 @@ def train(args): # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss - weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss target = latents - # Compute regular loss. TODO simplify this - loss = torch.mean( - (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), - 1, + # # Compute regular loss. TODO simplify this + # loss = torch.mean( + # (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + # 1, + # ) + # calculate loss + loss = train_util.conditional_loss( + model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None ) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + if weighting is not None: + loss = loss * weighting + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights loss = loss.mean() accelerator.backward(loss) - if not (args.fused_backward_pass or args.fused_optimizer_groups): + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): if accelerator.sync_gradients and args.max_grad_norm != 0.0: params_to_clip = [] for m in training_models: @@ -875,7 +956,7 @@ def train(args): else: # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook lr_scheduler.step() - if args.fused_optimizer_groups: + if args.blockwise_fused_optimizers: for i in range(1, len(optimizers)): lr_schedulers[i].step() @@ -884,6 +965,7 @@ def train(args): progress_bar.update(1) global_step += 1 + optimizer_eval_fn() sd3_train_utils.sample_images( accelerator, args, None, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs ) @@ -900,12 +982,13 @@ def train(args): epoch, num_train_epochs, global_step, - accelerator.unwrap_model(clip_l) if args.save_clip else None, - accelerator.unwrap_model(clip_g) if args.save_clip else None, - accelerator.unwrap_model(t5xxl) if args.save_t5xxl else None, - accelerator.unwrap_model(mmdit), + accelerator.unwrap_model(clip_l) if train_clip else None, + accelerator.unwrap_model(clip_g) if train_clip else None, + accelerator.unwrap_model(t5xxl) if train_t5xxl else None, + accelerator.unwrap_model(mmdit) if train_mmdit else None, vae, ) + optimizer_train_fn() current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず if len(accelerator.trackers) > 0: @@ -928,6 +1011,7 @@ def train(args): accelerator.wait_for_everyone() + optimizer_eval_fn() if args.save_every_n_epochs is not None: if accelerator.is_main_process: sd3_train_utils.save_sd3_model_on_epoch_end_or_stepwise( @@ -938,10 +1022,10 @@ def train(args): epoch, num_train_epochs, global_step, - accelerator.unwrap_model(clip_l) if args.save_clip else None, - accelerator.unwrap_model(clip_g) if args.save_clip else None, - accelerator.unwrap_model(t5xxl) if args.save_t5xxl else None, - accelerator.unwrap_model(mmdit), + accelerator.unwrap_model(clip_l) if train_clip else None, + accelerator.unwrap_model(clip_g) if train_clip else None, + accelerator.unwrap_model(t5xxl) if train_t5xxl else None, + accelerator.unwrap_model(mmdit) if train_mmdit else None, vae, ) @@ -958,6 +1042,7 @@ def train(args): t5xxl = accelerator.unwrap_model(t5xxl) accelerator.end_training() + optimizer_eval_fn() if args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) @@ -970,10 +1055,10 @@ def train(args): save_dtype, epoch, global_step, - clip_l if args.save_clip else None, - clip_g if args.save_clip else None, - t5xxl if args.save_t5xxl else None, - mmdit, + accelerator.unwrap_model(clip_l) if train_clip else None, + accelerator.unwrap_model(clip_g) if train_clip else None, + accelerator.unwrap_model(t5xxl) if train_t5xxl else None, + accelerator.unwrap_model(mmdit) if train_mmdit else None, vae, ) logger.info("model saved.") @@ -991,13 +1076,13 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) + add_custom_train_arguments(parser) sd3_train_utils.add_sd3_training_arguments(parser) parser.add_argument( "--train_text_encoder", action="store_true", help="train text encoder (CLIP-L and G) / text encoderも学習する" ) - # parser.add_argument("--train_t5xxl", action="store_true", help="train T5-XXL / T5-XXLも学習する") + parser.add_argument("--train_t5xxl", action="store_true", help="train T5-XXL / T5-XXLも学習する") parser.add_argument( "--use_t5xxl_cache_only", action="store_true", help="cache T5-XXL outputs only / T5-XXLの出力のみキャッシュする" ) @@ -1018,19 +1103,24 @@ def setup_parser() -> argparse.ArgumentParser: help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", ) - # TE training is disabled temporarily - # parser.add_argument( - # "--learning_rate_te1", - # type=float, - # default=None, - # help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率", - # ) - # parser.add_argument( - # "--learning_rate_te2", - # type=float, - # default=None, - # help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率", - # ) + parser.add_argument( + "--learning_rate_te1", + type=float, + default=None, + help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率", + ) + parser.add_argument( + "--learning_rate_te2", + type=float, + default=None, + help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率", + ) + parser.add_argument( + "--learning_rate_te3", + type=float, + default=None, + help="learning rate for text encoder 3 (T5-XXL) / text encoder 3 (T5-XXL)の学習率", + ) # parser.add_argument( # "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" @@ -1047,22 +1137,22 @@ def setup_parser() -> argparse.ArgumentParser: # help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " # + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", # ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) parser.add_argument( "--fused_optimizer_groups", type=int, default=None, - help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + help="[DOES NOT WORK] number of optimizer groups for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizerグループ数", ) parser.add_argument( "--skip_latents_validity_check", action="store_true", help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", ) - parser.add_argument( - "--skip_cache_check", - action="store_true", - help="skip cache (latents and text encoder outputs) check / キャッシュ(latentsとtext encoder outputs)のチェックをスキップする", - ) parser.add_argument( "--num_last_block_to_freeze", type=int, From e3c43bda49ec8c5a5cb784e29f8610f1ebff0a66 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 24 Oct 2024 20:35:47 +0900 Subject: [PATCH 258/356] reduce memory usage in sample image generation --- library/sd3_train_utils.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 9282482d..af8ecf2c 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -402,9 +402,6 @@ def sample_images( except Exception: pass - org_vae_device = vae.device # will be on cpu - vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device - if distributed_state.num_processes <= 1: # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. with torch.no_grad(): @@ -450,8 +447,6 @@ def sample_images( if cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) - vae.to(org_vae_device) - clean_memory_on_device(accelerator.device) @@ -531,12 +526,19 @@ def sample_image_inference( neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) # sample image - latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device) - latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype)) + clean_memory_on_device(accelerator.device) + with accelerator.autocast(): + latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device) # latent to image - with torch.no_grad(): - image = vae.decode(latents) + clean_memory_on_device(accelerator.device) + org_vae_device = vae.device # will be on cpu + vae.to(accelerator.device) + latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype)) + image = vae.decode(latents) + vae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + image = image.float() image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) From 0286114bd208717510b537d9acd940db48a158f3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 24 Oct 2024 21:28:42 +0900 Subject: [PATCH 259/356] support SD3.5L, fix final saving --- sd3_train.py | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index 6336b4cf..d4ab13a3 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -321,7 +321,7 @@ def train(args): accelerator.wait_for_everyone() # now we can delete Text Encoders to free memory - if args.use_t5xxl_cache_only: + if not args.use_t5xxl_cache_only: clip_l = None clip_g = None t5xxl = None @@ -330,6 +330,7 @@ def train(args): # load VAE for caching latents if sd3_state_dict is None: + logger.info(f"load state dict for MMDiT and VAE from {args.pretrained_model_name_or_path}") sd3_state_dict = utils.load_safetensors( args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype ) @@ -360,11 +361,6 @@ def train(args): # attn_mode == "torch" # ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" - # SD3 state dict may contain multiple models, so we need to load it and extract one by one. annoying. - logger.info(f"Loading SD3 models from {args.pretrained_model_name_or_path}") - device_to_load = accelerator.device if args.lowram else "cpu" - sd3_state_dict = utils.load_safetensors(args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors) - if args.gradient_checkpointing: mmdit.enable_gradient_checkpointing() @@ -555,7 +551,7 @@ def train(args): # clip_l.text_model.encoder.layers[-1].requires_grad_(False) # clip_l.text_model.final_layer_norm.requires_grad_(False) - # TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する + # move Text Encoders to GPU if not caching outputs if not args.cache_text_encoder_outputs: # make sure Text Encoders are on GPU # TODO support CPU for text encoders @@ -817,6 +813,13 @@ def train(args): # log empty object to commit the sample images to wandb accelerator.log({}, step=0) + # show model device and dtype + logger.info(f"mmdit device: {mmdit.device}, dtype: {mmdit.dtype}" if mmdit else "mmdit is None") + logger.info(f"clip_l device: {clip_l.device}, dtype: {clip_l.dtype}" if clip_l else "clip_l is None") + logger.info(f"clip_g device: {clip_g.device}, dtype: {clip_g.dtype}" if clip_g else "clip_g is None") + logger.info(f"t5xxl device: {t5xxl.device}, dtype: {t5xxl.dtype}" if t5xxl else "t5xxl is None") + logger.info(f"vae device: {vae.device}, dtype: {vae.dtype}" if vae is not None else "vae is None") + loss_recorder = train_util.LossRecorder() epoch = 0 # avoid error when max_train_steps is 0 for epoch in range(num_train_epochs): @@ -1055,10 +1058,10 @@ def train(args): save_dtype, epoch, global_step, - accelerator.unwrap_model(clip_l) if train_clip else None, - accelerator.unwrap_model(clip_g) if train_clip else None, - accelerator.unwrap_model(t5xxl) if train_t5xxl else None, - accelerator.unwrap_model(mmdit) if train_mmdit else None, + clip_l if train_clip else None, + clip_g if train_clip else None, + t5xxl if train_t5xxl else None, + mmdit if train_mmdit else None, vae, ) logger.info("model saved.") @@ -1153,6 +1156,16 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", ) + parser.add_argument( + "--blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of blocks (~640MB) to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) parser.add_argument( "--num_last_block_to_freeze", type=int, From f8c5146d71b1c40b69d80b7ea18c21bbb66b84f3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 24 Oct 2024 22:02:05 +0900 Subject: [PATCH 260/356] support block swap with fused_optimizer_pass --- library/sd3_models.py | 79 +++++++++++++++++++++++++++++++++++++++++-- sd3_train.py | 19 +++++++++-- 2 files changed, 94 insertions(+), 4 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index c81aa479..e5c5887a 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -4,6 +4,7 @@ # and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution! from ast import Tuple +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from functools import partial import math @@ -17,6 +18,8 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from transformers import CLIPTokenizer, T5TokenizerFast +from library.device_utils import clean_memory_on_device + from .utils import setup_logging setup_logging() @@ -848,6 +851,35 @@ class MMDiT(nn.Module): spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) return spatial_pos_embed + def enable_block_swap(self, num_blocks: int): + self.blocks_to_swap = num_blocks + + n = 1 # async block swap. 1 is enough + self.thread_pool = ThreadPoolExecutor(max_workers=n) + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu + if self.blocks_to_swap: + save_blocks = self.joint_blocks + self.joint_blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.joint_blocks = save_blocks + + def prepare_block_swap_before_forward(self): + # make: first n blocks are on cuda, and last n blocks are on cpu + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + # raise ValueError("Block swap is not enabled.") + return + num_blocks = len(self.joint_blocks) + for i in range(num_blocks - self.blocks_to_swap): + self.joint_blocks[i].to(self.device) + for i in range(num_blocks - self.blocks_to_swap, num_blocks): + self.joint_blocks[i].to("cpu") + clean_memory_on_device(self.device) + def forward( self, x: torch.Tensor, @@ -881,8 +913,51 @@ class MMDiT(nn.Module): 1, ) - for block in self.joint_blocks: - context, x = block(context, x, c) + if not self.blocks_to_swap: + for block in self.joint_blocks: + context, x = block(context, x, c) + else: + futures = {} + + def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): + # print(f"Moving {bidx_to_cpu} to cpu.") + block_to_cpu.to("cpu", non_blocking=True) + torch.cuda.empty_cache() + + # print(f"Moving {bidx_to_cuda} to cuda.") + block_to_cuda.to(self.device, non_blocking=True) + + torch.cuda.synchronize() + # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") + return block_idx_to_cpu, block_idx_to_cuda + + block_to_cpu = self.joint_blocks[block_idx_to_cpu] + block_to_cuda = self.joint_blocks[block_idx_to_cuda] + # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.") + return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) + + def wait_for_blocks_move(block_idx, ftrs): + if block_idx not in ftrs: + return + # print(f"Waiting for move blocks: {block_idx}") + # start_time = time.perf_counter() + ftr = ftrs.pop(block_idx) + ftr.result() + # torch.cuda.synchronize() + # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") + + for block_idx, block in enumerate(self.joint_blocks): + wait_for_blocks_move(block_idx, futures) + + context, x = block(context, x, c) + + if block_idx < self.blocks_to_swap: + block_idx_to_cpu = block_idx + block_idx_to_cuda = len(self.joint_blocks) - self.blocks_to_swap + block_idx + future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) + futures[block_idx_to_cuda] = future + x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify return x[:, :, :H, :W] diff --git a/sd3_train.py b/sd3_train.py index d4ab13a3..5e2efa6f 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -369,6 +369,14 @@ def train(args): if not train_mmdit: mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdit will not be prepared + # block swap + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + mmdit.enable_block_swap(args.blocks_to_swap) + if not cache_latents: # move to accelerator device vae.requires_grad_(False) @@ -575,7 +583,9 @@ def train(args): else: # acceleratorがなんかよろしくやってくれるらしい if train_mmdit: - mmdit = accelerator.prepare(mmdit) + mmdit = accelerator.prepare(mmdit, device_placement=[not is_swapping_blocks]) + if is_swapping_blocks: + accelerator.unwrap_model(mmdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage if train_clip: clip_l = accelerator.prepare(clip_l) clip_g = accelerator.prepare(clip_g) @@ -600,8 +610,10 @@ def train(args): block_to_cpu = block_to_cpu.to("cpu", non_blocking=True) torch.cuda.empty_cache() + # print(f"Backward: Move block {bidx_to_cuda} to CUDA") block_to_cuda = block_to_cuda.to(dvc, non_blocking=True) torch.cuda.synchronize() + # print(f"Backward: Done moving blocks {bidx_to_cpu} and {bidx_to_cuda}") return bidx_to_cpu, bidx_to_cuda block_to_cpu = blocks[block_idx_to_cpu] @@ -639,7 +651,7 @@ def train(args): grad_hook = None if blocks_to_swap: - is_block = param_name.startswith("double_blocks") + is_block = param_name.startswith("joint_blocks") if is_block: block_idx = int(param_name.split(".")[1]) if block_idx not in handled_block_indices: @@ -805,6 +817,9 @@ def train(args): init_kwargs=init_kwargs, ) + if is_swapping_blocks: + accelerator.unwrap_model(mmdit).prepare_block_swap_before_forward() + # For --sample_at_first optimizer_eval_fn() sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs) From b1e6504007aca20d15155d5c9fe880fb5e0002b8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 25 Oct 2024 18:56:25 +0900 Subject: [PATCH 261/356] update README --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index de5cddb9..ce28d004 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,9 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - bitsandbytes, transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. +- Fixed a bug where the loss weight was incorrect when `--debiased_estimation_loss` was specified with `--v_parameterization`. PR [#1715](https://github.com/kohya-ss/sd-scripts/pull/1715) Thanks to catboxanon! See [the PR](https://github.com/kohya-ss/sd-scripts/pull/1715) for details. + - Removed the warning when `--v_parameterization` is specified in SDXL and SD1.5. PR [#1717](https://github.com/kohya-ss/sd-scripts/pull/1717) + - There was a bug where the min_bucket_reso/max_bucket_reso in the dataset configuration did not create the correct resolution bucket if it was not divisible by bucket_reso_steps. These values are now warned and automatically rounded to a divisible value. Thanks to Maru-mee for raising the issue. Related PR [#1632](https://github.com/kohya-ss/sd-scripts/pull/1632) - `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds! From d2c549d7b2a9bb3e70b5af8539fd744b474a9607 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Fri, 25 Oct 2024 21:58:31 +0900 Subject: [PATCH 262/356] support SD3 LoRA --- library/sd3_models.py | 3 + library/sd3_train_utils.py | 111 +++-- library/sd3_utils.py | 2 +- networks/lora_sd3.py | 826 +++++++++++++++++++++++++++++++++++++ sd3_train.py | 30 +- sd3_train_network.py | 427 +++++++++++++++++++ train_network.py | 2 + 7 files changed, 1334 insertions(+), 67 deletions(-) create mode 100644 networks/lora_sd3.py create mode 100644 sd3_train_network.py diff --git a/library/sd3_models.py b/library/sd3_models.py index e5c5887a..5d09f74e 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -761,6 +761,9 @@ class MMDiT(nn.Module): self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels) # self.initialize_weights() + self.blocks_to_swap = None + self.thread_pool: Optional[ThreadPoolExecutor] = None + @property def model_type(self): return self._model_type diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index af8ecf2c..e3c649f7 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -198,6 +198,23 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用", ) + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=256, + help="maximum token length for T5-XXL. 256 is the default value / T5-XXLの最大トークン長。デフォルトは256", + ) + parser.add_argument( + "--apply_lg_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する", + ) + parser.add_argument( + "--apply_t5_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", + ) + # copy from Diffusers parser.add_argument( "--weighting_scheme", @@ -317,36 +334,36 @@ def do_sample( x = noise_scaled.to(device).to(dtype) # print(x.shape) - with torch.no_grad(): - for i in tqdm(range(len(sigmas) - 1)): - sigma_hat = sigmas[i] + # with torch.no_grad(): + for i in tqdm(range(len(sigmas) - 1)): + sigma_hat = sigmas[i] - timestep = model_sampling.timestep(sigma_hat).float() - timestep = torch.FloatTensor([timestep, timestep]).to(device) + timestep = model_sampling.timestep(sigma_hat).float() + timestep = torch.FloatTensor([timestep, timestep]).to(device) - x_c_nc = torch.cat([x, x], dim=0) - # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + x_c_nc = torch.cat([x, x], dim=0) + # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) - model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) - model_output = model_output.float() - batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + model_output = model_output.float() + batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) - pos_out, neg_out = batched.chunk(2) - denoised = neg_out + (pos_out - neg_out) * guidance_scale - # print(denoised.shape) + pos_out, neg_out = batched.chunk(2) + denoised = neg_out + (pos_out - neg_out) * guidance_scale + # print(denoised.shape) - # d = to_d(x, sigma_hat, denoised) - dims_to_append = x.ndim - sigma_hat.ndim - sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] - # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) - """Converts a denoiser output to a Karras ODE derivative.""" - d = (x - denoised) / sigma_hat_dims + # d = to_d(x, sigma_hat, denoised) + dims_to_append = x.ndim - sigma_hat.ndim + sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] + # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) + """Converts a denoiser output to a Karras ODE derivative.""" + d = (x - denoised) / sigma_hat_dims - dt = sigmas[i + 1] - sigma_hat + dt = sigmas[i + 1] - sigma_hat - # Euler method - x = x + d * dt - x = x.to(dtype) + # Euler method + x = x + d * dt + x = x.to(dtype) return x @@ -378,7 +395,7 @@ def sample_images( logger.info("") logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") - if not os.path.isfile(args.sample_prompts): + if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None: logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return @@ -386,7 +403,7 @@ def sample_images( # unwrap unet and text_encoder(s) mmdit = accelerator.unwrap_model(mmdit) - text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + text_encoders = None if text_encoders is None else [accelerator.unwrap_model(te) for te in text_encoders] # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) prompts = train_util.load_prompts(args.sample_prompts) @@ -404,7 +421,7 @@ def sample_images( if distributed_state.num_processes <= 1: # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. - with torch.no_grad(): + with torch.no_grad(), accelerator.autocast(): for prompt_dict in prompts: sample_image_inference( accelerator, @@ -506,29 +523,39 @@ def sample_image_inference( tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() - if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: - te_outputs = sample_prompts_te_outputs[prompt] - else: - l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(prompt) - te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens]) + def encode_prompt(prpt): + text_encoder_conds = [] + if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs: + text_encoder_conds = sample_prompts_te_outputs[prpt] + print(f"Using cached text encoder outputs for prompt: {prpt}") + if text_encoders is not None: + print(f"Encoding prompt: {prpt}") + tokens_and_masks = tokenize_strategy.tokenize(prpt) + # strategy has apply_t5_attn_mask option + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) - lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = te_outputs + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + return text_encoder_conds + + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(prompt) cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) # encode negative prompts - if sample_prompts_te_outputs and negative_prompt in sample_prompts_te_outputs: - neg_te_outputs = sample_prompts_te_outputs[negative_prompt] - else: - l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(negative_prompt) - neg_te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens]) - - lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = neg_te_outputs + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(negative_prompt) neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) # sample image clean_memory_on_device(accelerator.device) - with accelerator.autocast(): - latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device) + with accelerator.autocast(), torch.no_grad(): + # mmdit may be fp8, so we need weight_dtype here. vae is always in that dtype. + latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, vae.dtype, accelerator.device) # latent to image clean_memory_on_device(accelerator.device) @@ -538,7 +565,7 @@ def sample_image_inference( image = vae.decode(latents) vae.to(org_vae_device) clean_memory_on_device(accelerator.device) - + image = image.float() image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 9ad995d8..71e50de3 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -91,7 +91,7 @@ def load_mmdit( mmdit = sd3_models.create_sd3_mmdit(params, attn_mode) logger.info("Loading state dict...") - info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype) + info = mmdit.load_state_dict(mmdit_sd, strict=False, assign=True) logger.info(f"Loaded MMDiT: {info}") return mmdit diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py new file mode 100644 index 00000000..cbabf8da --- /dev/null +++ b/networks/lora_sd3.py @@ -0,0 +1,826 @@ +# temporary minimum implementation of LoRA +# SD3 doesn't have Conv2d, so we ignore it +# TODO commonize with the original/SD3/FLUX implementation + +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from transformers import CLIPTextModelWithProjection, T5EncoderModel +import numpy as np +import torch +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from networks.lora_flux import LoRAModule, LoRAInfModule +from library import sd3_models + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: sd3_models.SDVAE, + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], + mmdit, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv + context_attn_dim = kwargs.get("context_attn_dim", None) + context_mlp_dim = kwargs.get("context_mlp_dim", None) + context_mod_dim = kwargs.get("context_mod_dim", None) + x_attn_dim = kwargs.get("x_attn_dim", None) + x_mlp_dim = kwargs.get("x_mlp_dim", None) + x_mod_dim = kwargs.get("x_mod_dim", None) + if context_attn_dim is not None: + context_attn_dim = int(context_attn_dim) + if context_mlp_dim is not None: + context_mlp_dim = int(context_mlp_dim) + if context_mod_dim is not None: + context_mod_dim = int(context_mod_dim) + if x_attn_dim is not None: + x_attn_dim = int(x_attn_dim) + if x_mlp_dim is not None: + x_mlp_dim = int(x_mlp_dim) + if x_mod_dim is not None: + x_mod_dim = int(x_mod_dim) + type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim] + if all([d is None for d in type_dims]): + type_dims = None + + # emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear] + emb_dims = kwargs.get("emb_dims", None) + if emb_dims is not None: + emb_dims = emb_dims.strip() + if emb_dims.startswith("[") and emb_dims.endswith("]"): + emb_dims = emb_dims[1:-1] + emb_dims = [int(d) for d in emb_dims.split(",")] # is it better to use ast.literal_eval? + assert len(emb_dims) == 6, f"invalid emb_dims: {emb_dims}, must be 6 dimensions (context, t, x, y, final_mod, final_linear)" + + # double/single train blocks + def parse_block_selection(selection: str, total_blocks: int) -> List[bool]: + """ + Parse a block selection string and return a list of booleans. + + Args: + selection (str): A string specifying which blocks to select. + total_blocks (int): The total number of blocks available. + + Returns: + List[bool]: A list of booleans indicating which blocks are selected. + """ + if selection == "all": + return [True] * total_blocks + if selection == "none" or selection == "": + return [False] * total_blocks + + selected = [False] * total_blocks + ranges = selection.split(",") + + for r in ranges: + if "-" in r: + start, end = map(str.strip, r.split("-")) + start = int(start) + end = int(end) + assert 0 <= start < total_blocks, f"invalid start index: {start}" + assert 0 <= end < total_blocks, f"invalid end index: {end}" + assert start <= end, f"invalid range: {start}-{end}" + for i in range(start, end + 1): + selected[i] = True + else: + index = int(r) + assert 0 <= index < total_blocks, f"invalid index: {index}" + selected[index] = True + + return selected + + train_block_indices = kwargs.get("train_block_indices", None) + if train_block_indices is not None: + train_block_indices = parse_block_selection(train_block_indices, 999) # 999 is a dummy number + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + + # train T5XXL + train_t5xxl = kwargs.get("train_t5xxl", False) + if train_t5xxl is not None: + train_t5xxl = True if train_t5xxl == "True" else False + + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoders, + mmdit, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + type_dims=type_dims, + emb_dims=emb_dims, + train_block_indices=train_block_indices, + verbose=verbose, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, mmdit, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping, and train t5xxl + modules_dim = {} + modules_alpha = {} + train_t5xxl = None + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + if train_t5xxl is None or train_t5xxl is False: + train_t5xxl = "lora_te3" in lora_name + + if train_t5xxl is None: + train_t5xxl = False + + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoders, + mmdit, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + ) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + SD3_TARGET_REPLACE_MODULE = ["SingleDiTBlock"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"] + LORA_PREFIX_SD3 = "lora_unet" # make ComfyUI compatible + LORA_PREFIX_TEXT_ENCODER_CLIP_L = "lora_te1" + LORA_PREFIX_TEXT_ENCODER_CLIP_G = "lora_te2" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible + + def __init__( + self, + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], + unet: sd3_models.MMDiT, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + split_qkv: bool = False, + train_t5xxl: bool = False, + type_dims: Optional[List[int]] = None, + emb_dims: Optional[List[int]] = None, + train_block_indices: Optional[List[bool]] = None, + verbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.split_qkv = split_qkv + self.train_t5xxl = train_t5xxl + + self.type_dims = type_dims + self.emb_dims = emb_dims + self.train_block_indices = train_block_indices + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + self.emb_dims = [0] * 6 # create emb_dims + # verbose = True + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) + + qkv_dim = 0 + if self.split_qkv: + logger.info(f"split qkv for LoRA") + qkv_dim = unet.joint_blocks[0].context_block.attn.qkv.weight.size(0) + if train_t5xxl: + logger.info(f"train T5XXL as well") + + # create module instances + def create_modules( + is_mmdit: bool, + text_encoder_idx: Optional[int], + root_module: torch.nn.Module, + target_replace_modules: List[str], + filter: Optional[str] = None, + default_dim: Optional[int] = None, + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_SD3 + if is_mmdit + else [self.LORA_PREFIX_TEXT_ENCODER_CLIP_L, self.LORA_PREFIX_TEXT_ENCODER_CLIP_G, self.LORA_PREFIX_TEXT_ENCODER_T5][ + text_encoder_idx + ] + ) + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: # dirty hack for all modules + module = root_module # search all modules + + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + (name + "." if name else "") + child_name + lora_name = lora_name.replace(".", "_") + + if filter is not None and not filter in lora_name: + continue + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + + if is_mmdit and type_dims is not None: + # type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim] + identifier = [ + ("context_block", "attn"), + ("context_block", "mlp"), + ("context_block", "adaLN_modulation"), + ("x_block", "attn"), + ("x_block", "mlp"), + ("x_block", "adaLN_modulation"), + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d # may be 0 for skip + break + + if is_mmdit and dim and self.train_block_indices is not None and "joint_blocks" in lora_name: + # "lora_unet_joint_blocks_0_x_block_attn_proj..." + block_index = int(lora_name.split("_")[4]) # bit dirty + if self.train_block_indices is not None and not self.train_block_indices[block_index]: + dim = 0 + + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + # qkv split + split_dims = None + if is_mmdit and split_qkv: + if "joint_blocks" in lora_name and "qkv" in lora_name: + split_dims = [qkv_dim // 3] * 3 + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + split_dims=split_dims, + ) + loras.append(lora) + + if target_replace_modules is None: + break # all modules are searched + return loras, skipped + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + index = i + if not train_t5xxl and index >= 2: # 0: CLIP-L, 1: CLIP-G, 2: T5XXL, so we skip T5XXL if train_t5xxl is False + break + + logger.info(f"create LoRA for Text Encoder {index+1}:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.") + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + + # create LoRA for U-Net + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.SD3_TARGET_REPLACE_MODULE) + + # emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear] + if self.emb_dims: + for filter, in_dim in zip( + [ + "context_embedder", + "t_embedder", + "x_embedder", + "y_embedder", + "final_layer_adaLN_modulation", + "final_layer_linear", + ], + self.emb_dims, + ): + loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim) + self.unet_loras.extend(loras) + + logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") + + skipped = skipped_te + skipped_un + if verbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to split qkv + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # split qkv + for key in list(state_dict.keys()): + if not ("joint_blocks" in key and "qkv" in key): + continue + + weight = state_dict[key] + lora_name = key.split(".")[0] + if "lora_down" in key and "weight" in key: + # dense weight (rank*3, in_dim) + split_weight = torch.chunk(weight, 3, dim=0) + for i, split_w in enumerate(split_weight): + state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w + + del state_dict[key] + # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}") + elif "lora_up" in key and "weight" in key: + # sparse weight (out_dim=sum(split_dims), rank*3) + rank = weight.size(1) // 3 + i = 0 + split_dim = weight.shape[0] // 3 + for j in range(3): + state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dim, j * rank : (j + 1) * rank] + i += split_dim + del state_dict[key] + + # alpha is unchanged + + return super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination, prefix, keep_vars) + + # merge qkv + state_dict = super().state_dict(destination, prefix, keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if not ("joint_blocks" in key and "qkv" in key): + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(3)] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(3)] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + qkv_dim, rank = up_weights[0].size() + split_dim = qkv_dim // 3 + up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(3): + up_weight[i : i + split_dim, j * rank : (j + 1) * rank] = up_weights[j] + i += split_dim + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + + def apply_to(self, text_encoders, mmdit, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, mmdit, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if ( + key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_L) + or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_G) + or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5) + ): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_MMDIT): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of three elements + # if float, use the same value for all three + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): + text_encoder_lr = [default_lr, default_lr, default_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr), float(text_encoder_lr)] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0], text_encoder_lr[0]] + elif len(text_encoder_lr) == 2: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[1], text_encoder_lr[1]] + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, loraplus_ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if loraplus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * loraplus_ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te1_loras = [ + lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_L) + ] + te2_loras = [ + lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_G) + ] + te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)] + if len(te1_loras) > 0: + logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te2_loras) > 0: + logger.info(f"Text Encoder 2 (CLIP-G): {len(te2_loras)} modules, LR {text_encoder_lr[1]}") + params, descriptions = assemble_params(te2_loras, text_encoder_lr[1], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te3_loras) > 0: + logger.info(f"Text Encoder 3 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[2]}") + params, descriptions = assemble_params(te3_loras, text_encoder_lr[2], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 3 " + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/sd3_train.py b/sd3_train.py index 5e2efa6f..d12f7f56 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -220,12 +220,7 @@ def train(args): sd3_state_dict = None # load tokenizer and prepare tokenize strategy - if args.t5xxl_max_token_length is None: - t5xxl_max_token_length = 256 # default value for T5XXL - else: - t5xxl_max_token_length = args.t5xxl_max_token_length - - sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(t5xxl_max_token_length) + sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length) strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy) # load clip_l, clip_g, t5xxl for caching text encoder outputs @@ -876,6 +871,9 @@ def train(args): lg_out = None t5_out = None lg_pooled = None + l_attn_mask = None + g_attn_mask = None + t5_attn_mask = None if lg_out is None: # not cached or training, so get from text encoders @@ -885,7 +883,7 @@ def train(args): # text models in sd3_models require "cpu" for input_ids input_ids_clip_l = input_ids_clip_l.to("cpu") input_ids_clip_g = input_ids_clip_g.to("cpu") - lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens( + lg_out, _, lg_pooled, l_attn_mask, g_attn_mask, _ = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [clip_l, clip_g, None], [input_ids_clip_l, input_ids_clip_g, None, l_attn_mask, g_attn_mask, None], @@ -895,7 +893,7 @@ def train(args): _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] with torch.no_grad(): input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None - _, t5_out, _ = text_encoding_strategy.encode_tokens( + _, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] ) @@ -1104,22 +1102,6 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--use_t5xxl_cache_only", action="store_true", help="cache T5-XXL outputs only / T5-XXLの出力のみキャッシュする" ) - parser.add_argument( - "--t5xxl_max_token_length", - type=int, - default=None, - help="maximum token length for T5-XXL. 256 if omitted / T5-XXLの最大トークン数。省略時は256", - ) - parser.add_argument( - "--apply_lg_attn_mask", - action="store_true", - help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する", - ) - parser.add_argument( - "--apply_t5_attn_mask", - action="store_true", - help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", - ) parser.add_argument( "--learning_rate_te1", diff --git a/sd3_train_network.py b/sd3_train_network.py new file mode 100644 index 00000000..0f4ca93e --- /dev/null +++ b/sd3_train_network.py @@ -0,0 +1,427 @@ +import argparse +import copy +import math +import random +from typing import Any, Optional + +import torch +from accelerate import Accelerator +from library import strategy_sd3, utils +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3, train_util +import train_network +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class Sd3NetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + self.is_schnell: Optional[bool] = None + + def assert_extra_args(self, args, train_dataset_group): + super().assert_extra_args(args, train_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for SD3 + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # prepare CLIP-L/CLIP-G/T5XXL training flags + self.train_clip = not args.network_train_unet_only + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + + if args.max_token_length is not None: + logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + + def load_target_model(self, args, weight_dtype, accelerator): + # currently offload to cpu for some models + + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future + state_dict = utils.load_safetensors( + args.pretrained_model_name_or_path, "cpu", disable_mmap=args.disable_mmap_load_safetensors, dtype=loading_dtype + ) + mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu") + self.model_type = mmdit.model_type + + if args.fp8_base: + # check dtype of model + if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {mmdit.dtype}") + elif mmdit.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 SD3 model") + + clip_l = sd3_utils.load_clip_l( + args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + clip_l.eval() + clip_g = sd3_utils.load_clip_g( + args.clip_g, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + clip_g.eval() + + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + + # loading t5xxl to cpu takes a long time, so we should load to gpu in future + t5xxl = sd3_utils.load_t5xxl( + args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + vae = sd3_utils.load_vae( + args.vae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + + return mmdit.model_type, [clip_l, clip_g, t5xxl], vae, mmdit + + def get_tokenize_strategy(self, args): + logger.info(f"t5xxl_max_token_length: {args.t5xxl_max_token_length}") + return strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy): + return [tokenize_strategy.clip_l, tokenize_strategy.clip_g, tokenize_strategy.t5xxl] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sd3.Sd3TextEncodingStrategy(args.apply_lg_attn_mask, args.apply_t5_attn_mask) + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + if args.cache_text_encoder_outputs: + if self.train_clip and not self.train_t5xxl: + return text_encoders[0:2] # only CLIP-L/CLIP-G is needed for encoding because T5XXL is cached + else: + return None # no text encoders are needed for encoding because both are cached + else: + return text_encoders # CLIP-L, CLIP-G and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_clip, self.train_clip, self.train_t5xxl] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_clip or self.train_t5xxl, + apply_lg_attn_mask=args.apply_lg_attn_mask, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[2].to(accelerator.device) # may be fp8 + + if text_encoders[2].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(2, text_encoders[2], text_encoders[2].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[2].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, + text_encoders, + tokens_and_masks, + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # move back to cpu + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move CLIP-G back to cpu") + text_encoders[1].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[2].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) + text_encoders[2].to(accelerator.device) + + # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # # get size embeddings + # orig_size = batch["original_sizes_hw"] + # crop_size = batch["crop_top_lefts"] + # target_size = batch["target_sizes_hw"] + # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # # concat embeddings + # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + # return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, mmdit): + text_encoders = text_encoder # for compatibility + text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + + sd3_train_utils.sample_images( + accelerator, args, epoch, global_step, mmdit, vae, text_encoders, self.sample_prompts_te_outputs + ) + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + # shift 3.0 is the default value + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images) + + def shift_scale_latents(self, args, latents): + return latents + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: flux_models.Flux, + network, + weight_dtype, + train_unet, + ): + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps( + args, self.noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + ) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t.dtype.is_floating_point: + t.requires_grad_(True) + + # Predict the noise residual + lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds + text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled) + if not args.apply_lg_attn_mask: + l_attn_mask = None + g_attn_mask = None + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + # call model + with accelerator.autocast(): + # TODO support attention mask + model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled) + + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + # Preconditioning of the model outputs. + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = latents + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + model_pred_prior = unet( + noisy_model_input[diff_output_pr_indices], + timesteps[diff_output_pr_indices], + context=context[diff_output_pr_indices], + y=lg_pooled[diff_output_pr_indices], + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + + model_pred_prior = model_pred_prior * (-sigmas[diff_output_pr_indices]) + noisy_model_input[diff_output_pr_indices] + + # weighting for differential output preservation is not needed because it is already applied + + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + + return model_pred, target, timesteps, None, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, sd3=self.model_type) + + def update_metadata(self, metadata, args): + metadata["ss_apply_lg_attn_mask"] = args.apply_lg_attn_mask + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + if index == 0 or index == 1: # CLIP-L/CLIP-G + return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) + else: # T5XXL + text_encoder.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + if index == 0 or index == 1: # CLIP-L/CLIP-G + clip_type = "CLIP-L" if index == 0 else "CLIP-G" + logger.info(f"prepare CLIP-{clip_type} for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + else: # T5XXL + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"T5XXL already prepared for fp8") + else: + logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + prepare_fp8(text_encoder, weight_dtype) + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + sd3_train_utils.add_sd3_training_arguments(parser) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = Sd3NetworkTrainer() + trainer.train(args) diff --git a/train_network.py b/train_network.py index 9943b60b..aab1d84b 100644 --- a/train_network.py +++ b/train_network.py @@ -129,6 +129,7 @@ class NetworkTrainer: def get_models_for_text_encoding(self, args, accelerator, text_encoders): """ Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models. + FLUX.1 and SD3 may cache some outputs of the text encoder, so return the models that will be used for encoding (not cached). """ return text_encoders @@ -591,6 +592,7 @@ class NetworkTrainer: # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory + logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}") unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above unet.requires_grad_(False) From 0031d916f0fa035d5d48a25fcabadc149bfbb639 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Fri, 25 Oct 2024 23:20:38 +0900 Subject: [PATCH 263/356] add latent scaling/shifting --- sd3_train_network.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sd3_train_network.py b/sd3_train_network.py index 0f4ca93e..ecacf16c 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -6,7 +6,7 @@ from typing import Any, Optional import torch from accelerate import Accelerator -from library import strategy_sd3, utils +from library import sd3_models, strategy_sd3, utils from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -25,7 +25,6 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): def __init__(self): super().__init__() self.sample_prompts_te_outputs = None - self.is_schnell: Optional[bool] = None def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) @@ -268,7 +267,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): return vae.encode(images) def shift_scale_latents(self, args, latents): - return latents + return sd3_models.SDVAE.process_in(latents) def get_noise_pred_and_target( self, From 56bf7611644402996072bd8f909cf828ec7b27cc Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sat, 26 Oct 2024 17:29:24 +0900 Subject: [PATCH 264/356] fix errors in SD3 LoRA training with Text Encoders close #1724 --- library/strategy_sd3.py | 26 +++++++++++++------------- sd3_train_network.py | 2 +- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index dd08cf00..a27e99e6 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -68,9 +68,9 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): returned embeddings are not masked """ clip_l, clip_g, t5xxl = models - clip_l: CLIPTextModel - clip_g: CLIPTextModelWithProjection - t5xxl: T5EncoderModel + clip_l: Optional[CLIPTextModel] + clip_g: Optional[CLIPTextModelWithProjection] + t5xxl: Optional[T5EncoderModel] if apply_lg_attn_mask is None: apply_lg_attn_mask = self.apply_lg_attn_mask @@ -84,25 +84,23 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): if not apply_lg_attn_mask: l_attn_mask = None g_attn_mask = None - else: - l_attn_mask = l_attn_mask.to(clip_l.device) - g_attn_mask = g_attn_mask.to(clip_g.device) if not apply_t5_attn_mask: t5_attn_mask = None - else: - t5_attn_mask = t5_attn_mask.to(t5xxl.device) else: l_attn_mask = None g_attn_mask = None t5_attn_mask = None - if l_tokens is None: + if l_tokens is None or clip_l is None: assert g_tokens is None, "g_tokens must be None if l_tokens is None" lg_out = None lg_pooled = None else: with torch.no_grad(): assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" + l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None + g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None + prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True) l_pooled = prompt_embeds[0] l_out = prompt_embeds.hidden_states[-2] @@ -114,13 +112,15 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None lg_out = torch.cat([l_out, g_out], dim=-1) - if t5xxl is not None and t5_tokens is not None: + if t5xxl is None or t5_tokens is None: + t5_out = None + else: + t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None with torch.no_grad(): t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True) - else: - t5_out = None - return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] # masks are used for attention masking in transformer + # masks are used for attention masking in transformer + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] def concat_encodings( self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor diff --git a/sd3_train_network.py b/sd3_train_network.py index ecacf16c..129afed5 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -134,7 +134,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): def get_models_for_text_encoding(self, args, accelerator, text_encoders): if args.cache_text_encoder_outputs: if self.train_clip and not self.train_t5xxl: - return text_encoders[0:2] # only CLIP-L/CLIP-G is needed for encoding because T5XXL is cached + return text_encoders[0:2] + [None] # only CLIP-L/CLIP-G is needed for encoding because T5XXL is cached else: return None # no text encoders are needed for encoding because both are cached else: From 014064fd8186420abf5dfc7c99ad0b39fee33f8a Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sat, 26 Oct 2024 18:59:45 +0900 Subject: [PATCH 265/356] fix sample image generation without seed failed close #1726 --- library/sd3_train_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index e3c649f7..b04b86fb 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -316,6 +316,8 @@ def do_sample( # noise = get_noise(seed, latent).to(device) if seed is not None: generator = torch.manual_seed(seed) + else: + generator = None noise = ( torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu") .to(latent.dtype) From 56b4ea963ee745b73be1ad53c602854e3ff7ed16 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 26 Oct 2024 22:01:10 +0900 Subject: [PATCH 266/356] Fix LoRA metadata hash calculation bug in svd_merge_lora.py, sdxl_merge_lora.py, and resize_lora.py closes #1722 --- README.md | 8 ++++++++ networks/resize_lora.py | 15 ++++++++------- networks/sdxl_merge_lora.py | 15 ++++++++------- networks/svd_merge_lora.py | 15 ++++++++------- 4 files changed, 32 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 0be2f9a7..7f8508dc 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,14 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### Oct 26, 2024 / 2024-10-26: + +- Fixed a bug in `svd_merge_lora.py`, `sdxl_merge_lora.py`, and `resize_lora.py` where the hash value of LoRA metadata was not correctly calculated when the `save_precision` was different from the `precision` used in the calculation. See issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) for details. Thanks to JujoHotaru for raising the issue. +- It will be included in the next release. + +- `svd_merge_lora.py`、`sdxl_merge_lora.py`、`resize_lora.py`で、保存時の精度が計算時の精度と異なる場合、LoRAメタデータのハッシュ値が正しく計算されない不具合を修正しました。詳細は issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) をご覧ください。問題提起していただいた JujoHotaru 氏に感謝します。 +- 以上は次回リリースに含まれます。 + ### Sep 13, 2024 / 2024-09-13: - `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). diff --git a/networks/resize_lora.py b/networks/resize_lora.py index d697baa4..7df7ef0c 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -39,12 +39,7 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - +def save_to_file(file_name, state_dict, metadata): if model_util.is_safetensors(file_name): save_file(state_dict, file_name, metadata) else: @@ -349,12 +344,18 @@ def resize(args): metadata["ss_network_dim"] = "Dynamic" metadata["ss_network_alpha"] = "Dynamic" + # cast to save_dtype before calculating hashes + for key in list(state_dict.keys()): + value = state_dict[key] + if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype: + state_dict[key] = value.to(save_dtype) + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, metadata) + save_to_file(args.save_to, state_dict, metadata) def setup_parser() -> argparse.ArgumentParser: diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index 62f5a87d..b147eb44 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -35,12 +35,7 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, model, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - +def save_to_file(file_name, model, metadata): if os.path.splitext(file_name)[1] == ".safetensors": save_file(model, file_name, metadata=metadata) else: @@ -430,6 +425,12 @@ def merge(args): else: state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle) + # cast to save_dtype before calculating hashes + for key in list(state_dict.keys()): + value = state_dict[key] + if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype: + state_dict[key] = value.to(save_dtype) + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) @@ -445,7 +446,7 @@ def merge(args): metadata.update(sai_metadata) logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) + save_to_file(args.save_to, state_dict, metadata) def setup_parser() -> argparse.ArgumentParser: diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index b4b9e3bf..c520e7f8 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -216,12 +216,7 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - +def save_to_file(file_name, state_dict, metadata): if os.path.splitext(file_name)[1] == ".safetensors": save_file(state_dict, file_name, metadata=metadata) else: @@ -430,6 +425,12 @@ def merge(args): args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype ) + # cast to save_dtype before calculating hashes + for key in list(state_dict.keys()): + value = state_dict[key] + if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype: + state_dict[key] = value.to(save_dtype) + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) @@ -451,7 +452,7 @@ def merge(args): metadata.update(sai_metadata) logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, metadata) + save_to_file(args.save_to, state_dict, metadata) def setup_parser() -> argparse.ArgumentParser: From ca44e3e447fc1185ce188229b4e1a0f7f3bbbf66 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 27 Oct 2024 10:19:05 +0900 Subject: [PATCH 267/356] reduce VRAM usage, instead of increasing main RAM usage --- README.md | 7 +++++++ networks/svd_merge_lora.py | 11 +++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 7f8508dc..1e7d49af 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,13 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### Oct 27, 2024 / 2024-10-27: + +- `svd_merge_lora.py` VRAM usage has been reduced. However, main memory usage will increase (32GB is sufficient). +- This will be included in the next release. +- `svd_merge_lora.py` のVRAM使用量を削減しました。ただし、メインメモリの使用量は増加します(32GBあれば十分です)。 +- これは次回リリースに含まれます。 + ### Oct 26, 2024 / 2024-10-26: - Fixed a bug in `svd_merge_lora.py`, `sdxl_merge_lora.py`, and `resize_lora.py` where the hash value of LoRA metadata was not correctly calculated when the `save_precision` was different from the `precision` used in the calculation. See issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) for details. Thanks to JujoHotaru for raising the issue. diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index c520e7f8..c79b45ac 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -301,10 +301,10 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer # make original weight if not exist if lora_module_name not in merged_sd: weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype) - if device: - weight = weight.to(device) else: weight = merged_sd[lora_module_name] + if device: + weight = weight.to(device) # merge to weight if device: @@ -336,13 +336,16 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) weight = weight + ratio * conved * scale - merged_sd[lora_module_name] = weight + merged_sd[lora_module_name] = weight.to("cpu") # extract from merged weights logger.info("extract new lora...") merged_lora_sd = {} with torch.no_grad(): for lora_module_name, mat in tqdm(list(merged_sd.items())): + if device: + mat = mat.to(device) + conv2d = len(mat.size()) == 4 kernel_size = None if not conv2d else mat.size()[2:4] conv2d_3x3 = conv2d and kernel_size != (1, 1) @@ -381,7 +384,7 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous() merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous() - merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank) + merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank, device="cpu") # build minimum metadata dims = f"{new_rank}" From db2b4d41b9637cffd40a694c8e25847446a57aad Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 27 Oct 2024 16:42:58 +0900 Subject: [PATCH 268/356] Add dropout rate arguments for CLIP-L, CLIP-G, and T5, fix Text Encoders LoRA not trained --- library/sd3_train_utils.py | 18 ++++++++ library/strategy_sd3.py | 93 ++++++++++++++++++++++++++++++++++---- sd3_train.py | 15 ++++-- sd3_train_network.py | 16 ++++++- train_network.py | 13 ++++-- 5 files changed, 138 insertions(+), 17 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index b04b86fb..a0202ad4 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -214,6 +214,24 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): action="store_true", help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", ) + parser.add_argument( + "--clip_l_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for CLIP-L encoder, default is 0.0 / CLIP-Lエンコーダのドロップアウト率、デフォルトは0.0", + ) + parser.add_argument( + "--clip_g_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for CLIP-G encoder, default is 0.0 / CLIP-Gエンコーダのドロップアウト率、デフォルトは0.0", + ) + parser.add_argument( + "--t5_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0", + ) # copy from Diffusers parser.add_argument( diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index a27e99e6..d87ad7d1 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -1,5 +1,6 @@ import os import glob +import random from typing import Any, List, Optional, Tuple, Union import torch import numpy as np @@ -48,13 +49,23 @@ class Sd3TokenizeStrategy(TokenizeStrategy): class Sd3TextEncodingStrategy(TextEncodingStrategy): - def __init__(self, apply_lg_attn_mask: Optional[bool] = None, apply_t5_attn_mask: Optional[bool] = None) -> None: + def __init__( + self, + apply_lg_attn_mask: Optional[bool] = None, + apply_t5_attn_mask: Optional[bool] = None, + l_dropout_rate: float = 0.0, + g_dropout_rate: float = 0.0, + t5_dropout_rate: float = 0.0, + ) -> None: """ Args: apply_t5_attn_mask: Default value for apply_t5_attn_mask. """ self.apply_lg_attn_mask = apply_lg_attn_mask self.apply_t5_attn_mask = apply_t5_attn_mask + self.l_dropout_rate = l_dropout_rate + self.g_dropout_rate = g_dropout_rate + self.t5_dropout_rate = t5_dropout_rate def encode_tokens( self, @@ -63,6 +74,7 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): tokens: List[torch.Tensor], apply_lg_attn_mask: Optional[bool] = False, apply_t5_attn_mask: Optional[bool] = False, + enable_dropout: bool = True, ) -> List[torch.Tensor]: """ returned embeddings are not masked @@ -91,37 +103,92 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): g_attn_mask = None t5_attn_mask = None + # dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings + if l_tokens is None or clip_l is None: assert g_tokens is None, "g_tokens must be None if l_tokens is None" lg_out = None lg_pooled = None else: - with torch.no_grad(): - assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" + assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" + + drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate) + if drop_l: + l_pooled = torch.zeros((l_tokens.shape[0], 768), device=l_tokens.device, dtype=l_tokens.dtype) + l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=l_tokens.device, dtype=l_tokens.dtype) + if l_attn_mask is not None: + l_attn_mask = torch.zeros_like(l_attn_mask) + else: l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None - g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None - prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True) l_pooled = prompt_embeds[0] l_out = prompt_embeds.hidden_states[-2] + drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate) + if drop_g: + g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=g_tokens.device, dtype=g_tokens.dtype) + g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=g_tokens.device, dtype=g_tokens.dtype) + if g_attn_mask is not None: + g_attn_mask = torch.zeros_like(g_attn_mask) + else: + g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True) g_pooled = prompt_embeds[0] g_out = prompt_embeds.hidden_states[-2] - lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None - lg_out = torch.cat([l_out, g_out], dim=-1) + lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None + lg_out = torch.cat([l_out, g_out], dim=-1) if t5xxl is None or t5_tokens is None: t5_out = None else: - t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None - with torch.no_grad(): + drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate) + if drop_t5: + t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5_tokens.device, dtype=t5_tokens.dtype) + if t5_attn_mask is not None: + t5_attn_mask = torch.zeros_like(t5_attn_mask) + else: + t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True) # masks are used for attention masking in transformer return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] + def drop_cached_text_encoder_outputs( + self, + lg_out: torch.Tensor, + t5_out: torch.Tensor, + lg_pooled: torch.Tensor, + l_attn_mask: torch.Tensor, + g_attn_mask: torch.Tensor, + t5_attn_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # dropout: if enable_dropout is True, dropout is not applied. dropout means zeroing out embeddings + if lg_out is not None: + for i in range(lg_out.shape[0]): + drop_l = self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate + if drop_l: + lg_out[i, :, :768] = torch.zeros_like(lg_out[i, :, :768]) + lg_pooled[i, :768] = torch.zeros_like(lg_pooled[i, :768]) + if l_attn_mask is not None: + l_attn_mask[i] = torch.zeros_like(l_attn_mask[i]) + drop_g = self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate + if drop_g: + lg_out[i, :, 768:] = torch.zeros_like(lg_out[i, :, 768:]) + lg_pooled[i, 768:] = torch.zeros_like(lg_pooled[i, 768:]) + if g_attn_mask is not None: + g_attn_mask[i] = torch.zeros_like(g_attn_mask[i]) + + if t5_out is not None: + for i in range(t5_out.shape[0]): + drop_t5 = self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate + if drop_t5: + t5_out[i] = torch.zeros_like(t5_out[i]) + if t5_attn_mask is not None: + t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i]) + + return lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask + def concat_encodings( self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -207,8 +274,14 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): + # always disable dropout during caching lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask + tokenize_strategy, + models, + tokens_and_masks, + apply_lg_attn_mask=self.apply_lg_attn_mask, + apply_t5_attn_mask=self.apply_t5_attn_mask, + enable_dropout=False, ) if lg_out.dtype == torch.bfloat16: diff --git a/sd3_train.py b/sd3_train.py index d12f7f56..cdac945e 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -69,6 +69,11 @@ def train(args): # assert ( # not args.train_text_encoder or not args.cache_text_encoder_outputs # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), ( "when training text encoder, text encoder outputs must not be cached (except for T5XXL)" @@ -232,7 +237,9 @@ def train(args): assert clip_l is not None and clip_g is not None and t5xxl is not None, "clip_l, clip_g, t5xxl must be specified" # prepare text encoding strategy - text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy(args.apply_lg_attn_mask, args.apply_t5_attn_mask) + text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy( + args.apply_lg_attn_mask, args.apply_t5_attn_mask, args.clip_l_dropout_rate, args.clip_g_dropout_rate, args.t5_dropout_rate + ) strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) # 学習を準備する:モデルを適切な状態にする @@ -311,6 +318,7 @@ def train(args): tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask, + enable_dropout=False, ) accelerator.wait_for_everyone() @@ -863,6 +871,7 @@ def train(args): text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: + text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list if args.use_t5xxl_cache_only: lg_out = None @@ -878,7 +887,7 @@ def train(args): if lg_out is None: # not cached or training, so get from text encoders input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"] - with torch.set_grad_enabled(args.train_text_encoder): + with torch.set_grad_enabled(train_clip): # TODO support weighted captions # text models in sd3_models require "cpu" for input_ids input_ids_clip_l = input_ids_clip_l.to("cpu") @@ -891,7 +900,7 @@ def train(args): if t5_out is None: _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] - with torch.no_grad(): + with torch.set_grad_enabled(train_t5xxl): input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None _, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] diff --git a/sd3_train_network.py b/sd3_train_network.py index 129afed5..7b547127 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -120,7 +120,13 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): return latents_caching_strategy def get_text_encoding_strategy(self, args): - return strategy_sd3.Sd3TextEncodingStrategy(args.apply_lg_attn_mask, args.apply_t5_attn_mask) + return strategy_sd3.Sd3TextEncodingStrategy( + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + args.clip_l_dropout_rate, + args.clip_g_dropout_rate, + args.t5xxl_dropout_rate, + ) def post_process_network(self, args, accelerator, network, text_encoders, unet): # check t5xxl is trained or not @@ -408,6 +414,14 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): text_encoder.to(te_weight_dtype) # fp8 prepare_fp8(text_encoder, weight_dtype) + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + # drop cached text encoder outputs + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(text_encoder_outputs_list) + batch["text_encoder_outputs_list"] = text_encoder_outputs_list + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() diff --git a/train_network.py b/train_network.py index aab1d84b..9d78a4ef 100644 --- a/train_network.py +++ b/train_network.py @@ -272,6 +272,9 @@ class NetworkTrainer: def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): text_encoder.text_model.embeddings.to(dtype=weight_dtype) + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + pass + # endregion def train(self, args): @@ -1030,9 +1033,9 @@ class NetworkTrainer: # callback for step start if hasattr(accelerator.unwrap_model(network), "on_step_start"): - on_step_start = accelerator.unwrap_model(network).on_step_start + on_step_start_for_network = accelerator.unwrap_model(network).on_step_start else: - on_step_start = lambda *args, **kwargs: None + on_step_start_for_network = lambda *args, **kwargs: None # function for saving/removing def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): @@ -1113,7 +1116,10 @@ class NetworkTrainer: continue with accelerator.accumulate(training_model): - on_step_start(text_encoder, unet) + on_step_start_for_network(text_encoder, unet) + + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) @@ -1146,6 +1152,7 @@ class NetworkTrainer: if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: + # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: From a1255d637f545b0d6defebf080ca31f2370bf311 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 27 Oct 2024 17:03:36 +0900 Subject: [PATCH 269/356] Fix SD3 LoRA training to work (WIP) --- library/strategy_sd3.py | 20 ++++++++++---------- sd3_train_network.py | 15 ++++++++------- train_network.py | 20 ++++++++++++++++++++ 3 files changed, 38 insertions(+), 17 deletions(-) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index d87ad7d1..e57bb337 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -111,13 +111,13 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): lg_pooled = None else: assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" - + drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate) if drop_l: - l_pooled = torch.zeros((l_tokens.shape[0], 768), device=l_tokens.device, dtype=l_tokens.dtype) - l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=l_tokens.device, dtype=l_tokens.dtype) + l_pooled = torch.zeros((l_tokens.shape[0], 768), device=clip_l.device, dtype=clip_l.dtype) + l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=clip_l.device, dtype=clip_l.dtype) if l_attn_mask is not None: - l_attn_mask = torch.zeros_like(l_attn_mask) + l_attn_mask = torch.zeros_like(l_attn_mask, device=clip_l.device) else: l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True) @@ -126,10 +126,10 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate) if drop_g: - g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=g_tokens.device, dtype=g_tokens.dtype) - g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=g_tokens.device, dtype=g_tokens.dtype) + g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=clip_g.device, dtype=clip_g.dtype) + g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=clip_g.device, dtype=clip_g.dtype) if g_attn_mask is not None: - g_attn_mask = torch.zeros_like(g_attn_mask) + g_attn_mask = torch.zeros_like(g_attn_mask, device=clip_g.device) else: g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True) @@ -144,9 +144,9 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): else: drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate) if drop_t5: - t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5_tokens.device, dtype=t5_tokens.dtype) + t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5xxl.device, dtype=t5xxl.dtype) if t5_attn_mask is not None: - t5_attn_mask = torch.zeros_like(t5_attn_mask) + t5_attn_mask = torch.zeros_like(t5_attn_mask, device=t5xxl.device) else: t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True) @@ -187,7 +187,7 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): if t5_attn_mask is not None: t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i]) - return lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] def concat_encodings( self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor diff --git a/sd3_train_network.py b/sd3_train_network.py index 7b547127..620a336f 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -125,7 +125,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): args.apply_t5_attn_mask, args.clip_l_dropout_rate, args.clip_g_dropout_rate, - args.t5xxl_dropout_rate, + args.t5_dropout_rate, ) def post_process_network(self, args, accelerator, network, text_encoders, unet): @@ -415,12 +415,13 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): prepare_fp8(text_encoder, weight_dtype) def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): - # drop cached text encoder outputs - text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) - if text_encoder_outputs_list is not None: - text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(text_encoder_outputs_list) - batch["text_encoder_outputs_list"] = text_encoder_outputs_list + # # drop cached text encoder outputs + # text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + # if text_encoder_outputs_list is not None: + # text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + # text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) + # batch["text_encoder_outputs_list"] = text_encoder_outputs_list + pass def setup_parser() -> argparse.ArgumentParser: diff --git a/train_network.py b/train_network.py index 9d78a4ef..76936b2e 100644 --- a/train_network.py +++ b/train_network.py @@ -1151,6 +1151,17 @@ class NetworkTrainer: text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs + + # if text_encoder_outputs_list is not None: + # lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list + # for i in range(len(lg_out)): + # print( + # f"[{i}] cached L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, cached G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, " + # f"cached T5: {t5_out[i].max()}, " + # f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0}," + # f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}" + # ) + if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): @@ -1182,6 +1193,15 @@ class NetworkTrainer: if encoded_text_encoder_conds[i] is not None: text_encoder_conds[i] = encoded_text_encoder_conds[i] + # lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds + # for i in range(len(lg_out)): + # print( + # f"[{i}] train L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, train G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, " + # f"train T5: {t5_out[i].max()}, " + # f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0}," + # f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}" + # ) + # sample noise, call unet, get target noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( args, From d4f7849592c78455ddd268423528830ec5e55f47 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 27 Oct 2024 19:35:56 +0900 Subject: [PATCH 270/356] prevent unintended cast for disk cached TE outputs --- library/train_util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index d3c59ef9..d568523c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1615,7 +1615,6 @@ class BaseDataset(torch.utils.data.Dataset): text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( image_info.text_encoder_outputs_npz ) - text_encoder_outputs = [torch.FloatTensor(x) for x in text_encoder_outputs] else: tokenization_required = True text_encoder_outputs_list.append(text_encoder_outputs) From 1065dd1b56b4b18e211d3827fe22b459c81dd12c Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 27 Oct 2024 19:36:36 +0900 Subject: [PATCH 271/356] Fix to work dropout_rate for TEs --- flux_train_network.py | 2 +- library/strategy_flux.py | 1 + library/strategy_sd3.py | 142 +++++++++++++++++++++++++++------------ sd3_train_network.py | 15 ++--- train_network.py | 19 ------ 5 files changed, 108 insertions(+), 71 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index cffeb3b1..2b71a897 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -363,7 +363,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): if args.gradient_checkpointing: noisy_model_input.requires_grad_(True) for t in text_encoder_conds: - if t.dtype.is_floating_point: + if t is not None and t.dtype.is_floating_point: t.requires_grad_(True) img_ids.requires_grad_(True) guidance_vec.requires_grad_(True) diff --git a/library/strategy_flux.py b/library/strategy_flux.py index 0b0c34af..f662b62e 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -190,6 +190,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): apply_t5_attn_mask=apply_t5_attn_mask_i, ) else: + # it's fine that attn mask is not None. it's overwritten before calling the model if necessary info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index e57bb337..413169ec 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -89,19 +89,7 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): if apply_t5_attn_mask is None: apply_t5_attn_mask = self.apply_t5_attn_mask - l_tokens, g_tokens, t5_tokens = tokens[:3] - - if len(tokens) > 3: - l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] - if not apply_lg_attn_mask: - l_attn_mask = None - g_attn_mask = None - if not apply_t5_attn_mask: - t5_attn_mask = None - else: - l_attn_mask = None - g_attn_mask = None - t5_attn_mask = None + l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens # dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings @@ -109,47 +97,114 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy): assert g_tokens is None, "g_tokens must be None if l_tokens is None" lg_out = None lg_pooled = None + l_attn_mask = None + g_attn_mask = None else: assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" - drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate) - if drop_l: - l_pooled = torch.zeros((l_tokens.shape[0], 768), device=clip_l.device, dtype=clip_l.dtype) - l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=clip_l.device, dtype=clip_l.dtype) - if l_attn_mask is not None: - l_attn_mask = torch.zeros_like(l_attn_mask, device=clip_l.device) - else: - l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None - prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True) - l_pooled = prompt_embeds[0] - l_out = prompt_embeds.hidden_states[-2] + # drop some members of the batch: we do not call clip_l and clip_g for dropped members + batch_size, l_seq_len = l_tokens.shape + g_seq_len = g_tokens.shape[1] - drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate) - if drop_g: - g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=clip_g.device, dtype=clip_g.dtype) - g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=clip_g.device, dtype=clip_g.dtype) - if g_attn_mask is not None: - g_attn_mask = torch.zeros_like(g_attn_mask, device=clip_g.device) - else: - g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None - prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True) - g_pooled = prompt_embeds[0] - g_out = prompt_embeds.hidden_states[-2] + non_drop_l_indices = [] + non_drop_g_indices = [] + for i in range(l_tokens.shape[0]): + drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate) + drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate) + if not drop_l: + non_drop_l_indices.append(i) + if not drop_g: + non_drop_g_indices.append(i) - lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None + # filter out dropped members + if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size: + l_tokens = l_tokens[non_drop_l_indices] + l_attn_mask = l_attn_mask[non_drop_l_indices] + if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size: + g_tokens = g_tokens[non_drop_g_indices] + g_attn_mask = g_attn_mask[non_drop_g_indices] + + # call clip_l for non-dropped members + if len(non_drop_l_indices) > 0: + nd_l_attn_mask = l_attn_mask.to(clip_l.device) + prompt_embeds = clip_l( + l_tokens.to(clip_l.device), nd_l_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True + ) + nd_l_pooled = prompt_embeds[0] + nd_l_out = prompt_embeds.hidden_states[-2] + if len(non_drop_g_indices) > 0: + nd_g_attn_mask = g_attn_mask.to(clip_g.device) + prompt_embeds = clip_g( + g_tokens.to(clip_g.device), nd_g_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True + ) + nd_g_pooled = prompt_embeds[0] + nd_g_out = prompt_embeds.hidden_states[-2] + + # fill in the dropped members + if len(non_drop_l_indices) == batch_size: + l_pooled = nd_l_pooled + l_out = nd_l_out + else: + # model output is always float32 because of the models are wrapped with Accelerator + l_pooled = torch.zeros((batch_size, 768), device=clip_l.device, dtype=torch.float32) + l_out = torch.zeros((batch_size, l_seq_len, 768), device=clip_l.device, dtype=torch.float32) + l_attn_mask = torch.zeros((batch_size, l_seq_len), device=clip_l.device, dtype=l_attn_mask.dtype) + if len(non_drop_l_indices) > 0: + l_pooled[non_drop_l_indices] = nd_l_pooled + l_out[non_drop_l_indices] = nd_l_out + l_attn_mask[non_drop_l_indices] = nd_l_attn_mask + + if len(non_drop_g_indices) == batch_size: + g_pooled = nd_g_pooled + g_out = nd_g_out + else: + g_pooled = torch.zeros((batch_size, 1280), device=clip_g.device, dtype=torch.float32) + g_out = torch.zeros((batch_size, g_seq_len, 1280), device=clip_g.device, dtype=torch.float32) + g_attn_mask = torch.zeros((batch_size, g_seq_len), device=clip_g.device, dtype=g_attn_mask.dtype) + if len(non_drop_g_indices) > 0: + g_pooled[non_drop_g_indices] = nd_g_pooled + g_out[non_drop_g_indices] = nd_g_out + g_attn_mask[non_drop_g_indices] = nd_g_attn_mask + + lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) lg_out = torch.cat([l_out, g_out], dim=-1) if t5xxl is None or t5_tokens is None: t5_out = None + t5_attn_mask = None else: - drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate) - if drop_t5: - t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5xxl.device, dtype=t5xxl.dtype) - if t5_attn_mask is not None: - t5_attn_mask = torch.zeros_like(t5_attn_mask, device=t5xxl.device) + # drop some members of the batch: we do not call t5xxl for dropped members + batch_size, t5_seq_len = t5_tokens.shape + non_drop_t5_indices = [] + for i in range(t5_tokens.shape[0]): + drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate) + if not drop_t5: + non_drop_t5_indices.append(i) + + # filter out dropped members + if len(non_drop_t5_indices) > 0 and len(non_drop_t5_indices) < batch_size: + t5_tokens = t5_tokens[non_drop_t5_indices] + t5_attn_mask = t5_attn_mask[non_drop_t5_indices] + + # call t5xxl for non-dropped members + if len(non_drop_t5_indices) > 0: + nd_t5_attn_mask = t5_attn_mask.to(t5xxl.device) + nd_t5_out, _ = t5xxl( + t5_tokens.to(t5xxl.device), + nd_t5_attn_mask if apply_t5_attn_mask else None, + return_dict=False, + output_hidden_states=True, + ) + + # fill in the dropped members + if len(non_drop_t5_indices) == batch_size: + t5_out = nd_t5_out else: - t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None - t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True) + t5_out = torch.zeros((batch_size, t5_seq_len, 4096), device=t5xxl.device, dtype=torch.float32) + t5_attn_mask = torch.zeros((batch_size, t5_seq_len), device=t5xxl.device, dtype=t5_attn_mask.dtype) + if len(non_drop_t5_indices) > 0: + t5_out[non_drop_t5_indices] = nd_t5_out + t5_attn_mask[non_drop_t5_indices] = nd_t5_attn_mask # masks are used for attention masking in transformer return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] @@ -322,6 +377,7 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): apply_t5_attn_mask=apply_t5_attn_mask, ) else: + # it's fine that attn mask is not None. it's overwritten before calling the model if necessary info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i) diff --git a/sd3_train_network.py b/sd3_train_network.py index 620a336f..3506404a 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -300,7 +300,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): if args.gradient_checkpointing: noisy_model_input.requires_grad_(True) for t in text_encoder_conds: - if t.dtype.is_floating_point: + if t is not None and t.dtype.is_floating_point: t.requires_grad_(True) # Predict the noise residual @@ -415,13 +415,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): prepare_fp8(text_encoder, weight_dtype) def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): - # # drop cached text encoder outputs - # text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) - # if text_encoder_outputs_list is not None: - # text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - # text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) - # batch["text_encoder_outputs_list"] = text_encoder_outputs_list - pass + # drop cached text encoder outputs + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) + batch["text_encoder_outputs_list"] = text_encoder_outputs_list def setup_parser() -> argparse.ArgumentParser: diff --git a/train_network.py b/train_network.py index 76936b2e..b90aa420 100644 --- a/train_network.py +++ b/train_network.py @@ -1151,16 +1151,6 @@ class NetworkTrainer: text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs - - # if text_encoder_outputs_list is not None: - # lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list - # for i in range(len(lg_out)): - # print( - # f"[{i}] cached L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, cached G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, " - # f"cached T5: {t5_out[i].max()}, " - # f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0}," - # f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}" - # ) if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' @@ -1193,15 +1183,6 @@ class NetworkTrainer: if encoded_text_encoder_conds[i] is not None: text_encoder_conds[i] = encoded_text_encoder_conds[i] - # lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds - # for i in range(len(lg_out)): - # print( - # f"[{i}] train L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, train G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, " - # f"train T5: {t5_out[i].max()}, " - # f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0}," - # f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}" - # ) - # sample noise, call unet, get target noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( args, From af8e216035128767234163a24debf2f4df5aa36d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 28 Oct 2024 22:08:57 +0900 Subject: [PATCH 272/356] Fix sample image gen to work with block swap --- library/sd3_train_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index a0202ad4..054d1b4a 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -364,6 +364,7 @@ def do_sample( x_c_nc = torch.cat([x, x], dim=0) # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + mmdit.prepare_block_swap_before_forward() model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) model_output = model_output.float() batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) @@ -385,6 +386,7 @@ def do_sample( x = x + d * dt x = x.to(dtype) + mmdit.prepare_block_swap_before_forward() return x From 75554867ce390ec0957cc52a70c0695e19c71fe2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 29 Oct 2024 08:34:31 +0900 Subject: [PATCH 273/356] Fix error on saving T5XXL --- library/sd3_train_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 054d1b4a..1702e81c 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -75,7 +75,14 @@ def save_models( save_file(clip_g.state_dict(), clip_g_path) if t5xxl is not None: t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors") - save_file(t5xxl.state_dict(), t5xxl_path) + t5xxl_state_dict = t5xxl.state_dict() + + # replace "shared.weight" with copy of it to avoid annoying shared tensor error on safetensors.save_file + shared_weight = t5xxl_state_dict["shared.weight"] + shared_weight_copy = shared_weight.detach().clone() + t5xxl_state_dict["shared.weight"] = shared_weight_copy + + save_file(t5xxl_state_dict, t5xxl_path) def save_sd3_model_on_train_end( From 0af4edd8a63d7fcdf02bdcbd11b8770fd1cae162 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Tue, 29 Oct 2024 21:51:56 +0900 Subject: [PATCH 274/356] Fix split_qkv --- networks/lora_sd3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py index cbabf8da..249298b3 100644 --- a/networks/lora_sd3.py +++ b/networks/lora_sd3.py @@ -540,8 +540,8 @@ class LoRANetwork(torch.nn.Module): down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) # merge up weight (sum of split_dim, rank*3) - qkv_dim, rank = up_weights[0].size() - split_dim = qkv_dim // 3 + split_dim, rank = up_weights[0].size() + qkv_dim = split_dim * 3 up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) i = 0 for j in range(3): From d4e19fbd5e34e90347f189a8ba1f77e8878fe0ca Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Tue, 29 Oct 2024 21:52:04 +0900 Subject: [PATCH 275/356] Support Lora --- sd3_minimal_inference.py | 60 +++++++++++++++++++++++++++++----------- 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index d099fe18..86dba246 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -10,11 +10,13 @@ import numpy as np import torch from safetensors.torch import safe_open, load_file +import torch.amp from tqdm import tqdm from PIL import Image from transformers import CLIPTextModelWithProjection, T5EncoderModel from library.device_utils import init_ipex, get_preferred_device +from networks import lora_sd3 init_ipex() @@ -104,7 +106,8 @@ def do_sample( x_c_nc = torch.cat([x, x], dim=0) # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) - model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + with torch.autocast(device_type=device.type, dtype=dtype): + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) model_output = model_output.float() batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) @@ -153,7 +156,7 @@ def generate_image( clip_g.to(device) t5xxl.to(device) - with torch.no_grad(): + with torch.autocast(device_type=device.type, dtype=mmdit.dtype), torch.no_grad(): tokens_and_masks = tokenize_strategy.tokenize(prompt) lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens( tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask @@ -233,13 +236,14 @@ if __name__ == "__main__": parser.add_argument("--bf16", action="store_true") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--steps", type=int, default=50) - # parser.add_argument( - # "--lora_weights", - # type=str, - # nargs="*", - # default=[], - # help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)", - # ) + parser.add_argument( + "--lora_weights", + type=str, + nargs="*", + default=[], + help="LoRA weights, only supports networks.lora_sd3, each argument is a `path;multiplier` (semi-colon separated)", + ) + parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") parser.add_argument("--width", type=int, default=target_width) parser.add_argument("--height", type=int, default=target_height) parser.add_argument("--interactive", action="store_true") @@ -294,6 +298,30 @@ if __name__ == "__main__": tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length) encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() + # LoRA + lora_models: list[lora_sd3.LoRANetwork] = [] + for weights_file in args.lora_weights: + if ";" in weights_file: + weights_file, multiplier = weights_file.split(";") + multiplier = float(multiplier) + else: + multiplier = 1.0 + + weights_sd = load_file(weights_file) + module = lora_sd3 + lora_model, _ = module.create_network_from_weights(multiplier, None, vae, [clip_l, clip_g, t5xxl], mmdit, weights_sd, True) + + if args.merge_lora_weights: + lora_model.merge_to([clip_l, clip_g, t5xxl], mmdit, weights_sd) + else: + lora_model.apply_to([clip_l, clip_g, t5xxl], mmdit) + info = lora_model.load_state_dict(weights_sd, strict=True) + logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + lora_model.eval() + lora_model.to(device) + + lora_models.append(lora_model) + if not args.interactive: generate_image( mmdit, @@ -344,13 +372,13 @@ if __name__ == "__main__": steps = int(opt[1:].strip()) elif opt.startswith("d"): seed = int(opt[1:].strip()) - # elif opt.startswith("m"): - # mutipliers = opt[1:].strip().split(",") - # if len(mutipliers) != len(lora_models): - # logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") - # continue - # for i, lora_model in enumerate(lora_models): - # lora_model.set_multiplier(float(mutipliers[i])) + elif opt.startswith("m"): + mutipliers = opt[1:].strip().split(",") + if len(mutipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(mutipliers[i])) elif opt.startswith("n"): negative_prompt = opt[1:].strip() if negative_prompt == "-": From 1e2f7b0e44ee656cd8d0ca8268aa1371618031ac Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 29 Oct 2024 22:11:04 +0900 Subject: [PATCH 276/356] Support for checkpoint files with a mysterious prefix "model.diffusion_model." --- library/flux_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/library/flux_utils.py b/library/flux_utils.py index 7a1ec37b..4403835f 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -73,6 +73,10 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int with safe_open(ckpt_path, framework="pt") as f: keys.extend(f.keys()) + # if the key has annoying prefix, remove it + if keys[0].startswith("model.diffusion_model."): + keys = [key.replace("model.diffusion_model.", "") for key in keys] + is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) @@ -141,6 +145,13 @@ def load_flow_model( sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks) logger.info("Converted Diffusers to BFL") + # if the key has annoying prefix, remove it + for key in list(sd.keys()): + new_key = key.replace("model.diffusion_model.", "") + if new_key == key: + break # the model doesn't have annoying prefix + sd[new_key] = sd.pop(key) + info = model.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Flux: {info}") return is_schnell, model From ce5b5325829538c03ff9ce80a79fe2c84ca5283c Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Tue, 29 Oct 2024 22:29:24 +0900 Subject: [PATCH 277/356] Fix additional LoRA to work --- networks/lora_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py index 249298b3..c1eb68b8 100644 --- a/networks/lora_sd3.py +++ b/networks/lora_sd3.py @@ -428,7 +428,7 @@ class LoRANetwork(torch.nn.Module): for filter, in_dim in zip( [ "context_embedder", - "t_embedder", + "_t_embedder", # don't use "t_embedder" because it's used in "context_embedder" "x_embedder", "y_embedder", "final_layer_adaLN_modulation", From b502f584886fbf52f9a180981efe276ea8509de7 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Tue, 29 Oct 2024 23:29:50 +0900 Subject: [PATCH 278/356] Fix emb_dim to work. --- networks/lora_sd3.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py index c1eb68b8..efe20245 100644 --- a/networks/lora_sd3.py +++ b/networks/lora_sd3.py @@ -307,6 +307,7 @@ class LoRANetwork(torch.nn.Module): target_replace_modules: List[str], filter: Optional[str] = None, default_dim: Optional[int] = None, + include_conv2d_if_filter: bool = False, ) -> List[LoRAModule]: prefix = ( self.LORA_PREFIX_SD3 @@ -332,8 +333,11 @@ class LoRANetwork(torch.nn.Module): lora_name = prefix + "." + (name + "." if name else "") + child_name lora_name = lora_name.replace(".", "_") - if filter is not None and not filter in lora_name: - continue + force_incl_conv2d = False + if filter is not None: + if not filter in lora_name: + continue + force_incl_conv2d = include_conv2d_if_filter dim = None alpha = None @@ -373,6 +377,10 @@ class LoRANetwork(torch.nn.Module): elif self.conv_lora_dim is not None: dim = self.conv_lora_dim alpha = self.conv_alpha + elif force_incl_conv2d: + # x_embedder + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha if dim is None or dim == 0: # skipした情報を出力 @@ -428,7 +436,7 @@ class LoRANetwork(torch.nn.Module): for filter, in_dim in zip( [ "context_embedder", - "_t_embedder", # don't use "t_embedder" because it's used in "context_embedder" + "_t_embedder", # don't use "t_embedder" because it's used in "context_embedder" "x_embedder", "y_embedder", "final_layer_adaLN_modulation", @@ -436,7 +444,12 @@ class LoRANetwork(torch.nn.Module): ], self.emb_dims, ): - loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim) + # x_embedder is conv2d, so we need to include it + loras, _ = create_modules( + True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder" + ) + # if len(loras) > 0: + # logger.info(f"create LoRA for {filter}: {len(loras)} modules.") self.unet_loras.extend(loras) logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.") From bdddc20d68a7441cccfcf0009528fdd59403b94a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 30 Oct 2024 12:51:49 +0900 Subject: [PATCH 279/356] support SD3.5M --- library/sd3_models.py | 128 +++++++++++++++++++++++-------------- library/sd3_train_utils.py | 7 ++ library/sd3_utils.py | 13 ++-- sd3_train.py | 8 +-- sd3_train_network.py | 1 + 5 files changed, 99 insertions(+), 58 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 5d09f74e..840f9186 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -51,7 +51,7 @@ class SD3Params: pos_embed_max_size: int adm_in_channels: int qk_norm: Optional[str] - x_block_self_attn_layers: List[int] + x_block_self_attn_layers: list[int] context_embedder_in_features: int context_embedder_out_features: int model_type: str @@ -510,6 +510,7 @@ class SingleDiTBlock(nn.Module): scale_mod_only: bool = False, swiglu: bool = False, qk_norm: Optional[str] = None, + x_block_self_attn: bool = False, **block_kwargs, ): super().__init__() @@ -519,13 +520,14 @@ class SingleDiTBlock(nn.Module): self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) else: self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.attn = AttentionLinears( - dim=hidden_size, - num_heads=num_heads, - qkv_bias=qkv_bias, - pre_only=pre_only, - qk_norm=qk_norm, - ) + self.attn = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=pre_only, qk_norm=qk_norm) + + self.x_block_self_attn = x_block_self_attn + if self.x_block_self_attn: + assert not pre_only + assert not scale_mod_only + self.attn2 = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=False, qk_norm=qk_norm) + if not pre_only: if not rmsnorm: self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) @@ -546,7 +548,9 @@ class SingleDiTBlock(nn.Module): multiple_of=256, ) self.scale_mod_only = scale_mod_only - if not scale_mod_only: + if self.x_block_self_attn: + n_mods = 9 + elif not scale_mod_only: n_mods = 6 if not pre_only else 2 else: n_mods = 4 if not pre_only else 1 @@ -556,63 +560,64 @@ class SingleDiTBlock(nn.Module): def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: if not self.pre_only: if not self.scale_mod_only: - ( - shift_msa, - scale_msa, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, - ) = self.adaLN_modulation( - c - ).chunk(6, dim=-1) + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(6, dim=-1) else: shift_msa = None shift_mlp = None - ( - scale_msa, - gate_msa, - scale_mlp, - gate_mlp, - ) = self.adaLN_modulation( - c - ).chunk(4, dim=-1) + (scale_msa, gate_msa, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(4, dim=-1) qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) - return qkv, ( - x, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, - ) + return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp) else: if not self.scale_mod_only: - ( - shift_msa, - scale_msa, - ) = self.adaLN_modulation( - c - ).chunk(2, dim=-1) + (shift_msa, scale_msa) = self.adaLN_modulation(c).chunk(2, dim=-1) else: shift_msa = None scale_msa = self.adaLN_modulation(c) qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) return qkv, None + def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + assert self.x_block_self_attn + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2) = self.adaLN_modulation( + c + ).chunk(9, dim=1) + x_norm = self.norm1(x) + qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa)) + qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2)) + return qkv, qkv2, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2) + def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp): assert not self.pre_only x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn) x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) return x + def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2, attn1_dropout: float = 0.0): + assert not self.pre_only + if attn1_dropout > 0.0: + # Use torch.bernoulli to implement dropout, only dropout the batch dimension + attn1_dropout = torch.bernoulli(torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device)) + attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout + else: + attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + attn_ + attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2) + x = x + attn2_ + mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + x = x + mlp_ + return x + # JointBlock + block_mixing in mmdit.py class MMDiTBlock(nn.Module): def __init__(self, *args, **kwargs): super().__init__() pre_only = kwargs.pop("pre_only") + x_block_self_attn = kwargs.pop("x_block_self_attn") + self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs) - self.x_block = SingleDiTBlock(*args, pre_only=False, **kwargs) + self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs) + self.head_dim = self.x_block.attn.head_dim self.mode = self.x_block.attn_mode self.gradient_checkpointing = False @@ -622,7 +627,11 @@ class MMDiTBlock(nn.Module): def _forward(self, context, x, c): ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c) - x_qkv, x_intermediate = self.x_block.pre_attention(x, c) + + if self.x_block.x_block_self_attn: + x_qkv, x_qkv2, x_intermediates = self.x_block.pre_attention_x(x, c) + else: + x_qkv, x_intermediates = self.x_block.pre_attention(x, c) ctx_len = ctx_qkv[0].size(1) @@ -634,11 +643,18 @@ class MMDiTBlock(nn.Module): ctx_attn_out = attn[:, :ctx_len] x_attn_out = attn[:, ctx_len:] - x = self.x_block.post_attention(x_attn_out, *x_intermediate) + if self.x_block.x_block_self_attn: + x_q2, x_k2, x_v2 = x_qkv2 + attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads) + x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates) + else: + x = self.x_block.post_attention(x_attn_out, *x_intermediates) + if not self.context_block.pre_only: context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate) else: context = None + return context, x def forward(self, *args, **kwargs): @@ -678,7 +694,9 @@ class MMDiT(nn.Module): pos_embed_max_size: Optional[int] = None, num_patches=None, qk_norm: Optional[str] = None, + x_block_self_attn_layers: Optional[list[int]] = [], qkv_bias: bool = True, + pos_emb_random_crop_rate: float = 0.0, model_type: str = "sd3m", ): super().__init__() @@ -691,6 +709,8 @@ class MMDiT(nn.Module): self.pos_embed_scaling_factor = pos_embed_scaling_factor self.pos_embed_offset = pos_embed_offset self.pos_embed_max_size = pos_embed_max_size + self.x_block_self_attn_layers = x_block_self_attn_layers + self.pos_emb_random_crop_rate = pos_emb_random_crop_rate self.gradient_checkpointing = use_checkpoint # hidden_size = default(hidden_size, 64 * depth) @@ -751,6 +771,7 @@ class MMDiT(nn.Module): scale_mod_only=scale_mod_only, swiglu=swiglu, qk_norm=qk_norm, + x_block_self_attn=(i in self.x_block_self_attn_layers), ) for i in range(depth) ] @@ -832,7 +853,10 @@ class MMDiT(nn.Module): nn.init.constant_(self.final_layer.linear.weight, 0) nn.init.constant_(self.final_layer.linear.bias, 0) - def cropped_pos_embed(self, h, w, device=None): + def set_pos_emb_random_crop_rate(self, rate: float): + self.pos_emb_random_crop_rate = rate + + def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False): p = self.x_embedder.patch_size # patched size h = (h + 1) // p @@ -842,8 +866,14 @@ class MMDiT(nn.Module): assert self.pos_embed_max_size is not None assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size) assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size) - top = (self.pos_embed_max_size - h) // 2 - left = (self.pos_embed_max_size - w) // 2 + + if not random_crop: + top = (self.pos_embed_max_size - h) // 2 + left = (self.pos_embed_max_size - w) // 2 + else: + top = torch.randint(0, self.pos_embed_max_size - h + 1, (1,)).item() + left = torch.randint(0, self.pos_embed_max_size - w + 1, (1,)).item() + spatial_pos_embed = self.pos_embed.reshape( 1, self.pos_embed_max_size, @@ -896,9 +926,12 @@ class MMDiT(nn.Module): t: (N,) tensor of diffusion timesteps y: (N, D) tensor of class labels """ + pos_emb_random_crop = ( + False if self.pos_emb_random_crop_rate == 0.0 else torch.rand(1).item() < self.pos_emb_random_crop_rate + ) B, C, H, W = x.shape - x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device).to(dtype=x.dtype) + x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype) c = self.t_embedder(t, dtype=x.dtype) # (N, D) if y is not None and self.y_embedder is not None: y = self.y_embedder(y) # (N, D) @@ -977,6 +1010,7 @@ def create_sd3_mmdit(params: SD3Params, attn_mode: str = "torch") -> MMDiT: depth=params.depth, mlp_ratio=4, qk_norm=params.qk_norm, + x_block_self_attn_layers=params.x_block_self_attn_layers, num_patches=params.num_patches, attn_mode=attn_mode, model_type=params.model_type, diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 1702e81c..86f0c9c0 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -239,6 +239,13 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): default=0.0, help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0", ) + parser.add_argument( + "--pos_emb_random_crop_rate", + type=float, + default=0.0, + help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M" + " / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります", + ) # copy from Diffusers parser.add_argument( diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 71e50de3..1861dfbc 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -41,20 +41,21 @@ def analyze_state_dict_state(state_dict: Dict, prefix: str = ""): # x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])) x_block_self_attn_layers = [] - re_attn = re.compile(r".(\d+).x_block.attn2.ln_k.weight") + re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight") for key in list(state_dict.keys()): - m = re_attn.match(key) + m = re_attn.search(key) if m: x_block_self_attn_layers.append(int(m.group(1))) - assert len(x_block_self_attn_layers) == 0, "x_block_self_attn_layers is not supported" - context_embedder_in_features = context_shape[1] context_embedder_out_features = context_shape[0] - # only supports 3-5-large and 3-medium + # only supports 3-5-large, medium or 3-medium if qk_norm is not None: - model_type = "3-5-large" + if len(x_block_self_attn_layers) == 0: + model_type = "3-5-large" + else: + model_type = "3-5-medium" else: model_type = "3-medium" diff --git a/sd3_train.py b/sd3_train.py index cdac945e..df273690 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -353,17 +353,15 @@ def train(args): accelerator.wait_for_everyone() # load MMDIT - mmdit = sd3_utils.load_mmdit( - sd3_state_dict, - model_dtype, - "cpu", - ) + mmdit = sd3_utils.load_mmdit(sd3_state_dict, model_dtype, "cpu") # attn_mode = "xformers" if args.xformers else "torch" # assert ( # attn_mode == "torch" # ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate) + if args.gradient_checkpointing: mmdit.enable_gradient_checkpointing() diff --git a/sd3_train_network.py b/sd3_train_network.py index 3506404a..3d2a7571 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -65,6 +65,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): ) mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu") self.model_type = mmdit.model_type + mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate) if args.fp8_base: # check dtype of model From 70a179e446219b66f208e4fbb37b74c5d77d6086 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 30 Oct 2024 14:34:19 +0900 Subject: [PATCH 280/356] Fix to use SDPA instead of xformers --- library/sd3_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 840f9186..60356e82 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -645,7 +645,7 @@ class MMDiTBlock(nn.Module): if self.x_block.x_block_self_attn: x_q2, x_k2, x_v2 = x_qkv2 - attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads) + attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads, mode=self.mode) x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates) else: x = self.x_block.post_attention(x_attn_out, *x_intermediates) From 1434d8506f3ccc4ae6cc005a19531dba3cbb9fb9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 31 Oct 2024 19:58:22 +0900 Subject: [PATCH 281/356] Support SD3.5M multi resolutional training --- library/sd3_models.py | 177 ++++++++++++++++++++++++++++++++++++- library/sd3_train_utils.py | 6 ++ library/strategy_base.py | 2 +- library/strategy_flux.py | 4 +- library/strategy_sd3.py | 11 ++- library/train_util.py | 3 + sd3_train.py | 9 +- sd3_train_network.py | 13 ++- 8 files changed, 215 insertions(+), 10 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 60356e82..0eca94e2 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -88,6 +88,78 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): return emb +def get_scaled_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, sample_size=64, base_size=16): + """ + This function is contributed by KohakuBlueleaf. Thanks for the contribution! + + Creates scaled 2D sinusoidal positional embeddings that maintain consistent relative positions + when the resolution differs from the training resolution. + + Args: + embed_dim (int): Dimension of the positional embedding. + grid_size (int or tuple): Size of the position grid (H, W). If int, assumes square grid. + cls_token (bool): Whether to include class token. Defaults to False. + extra_tokens (int): Number of extra tokens (e.g., cls_token). Defaults to 0. + sample_size (int): Reference resolution (typically training resolution). Defaults to 64. + base_size (int): Base grid size used during training. Defaults to 16. + + Returns: + numpy.ndarray: Positional embeddings of shape (H*W, embed_dim) or + (H*W + extra_tokens, embed_dim) if cls_token is True. + """ + # Convert grid_size to tuple if it's an integer + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + # Create normalized grid coordinates (0 to 1) + grid_h = np.arange(grid_size[0], dtype=np.float32) / grid_size[0] + grid_w = np.arange(grid_size[1], dtype=np.float32) / grid_size[1] + + # Calculate scaling factors for height and width + # This ensures that the central region matches the original resolution's embeddings + scale_h = base_size * grid_size[0] / (sample_size) + scale_w = base_size * grid_size[1] / (sample_size) + + # Calculate shift values to center the original resolution's embedding region + # This ensures that the central sample_size x sample_size region has similar + # positional embeddings to the original resolution + shift_h = 1 * scale_h * (grid_size[0] - sample_size) / (2 * grid_size[0]) + shift_w = 1 * scale_w * (grid_size[1] - sample_size) / (2 * grid_size[1]) + + # Apply scaling and shifting to create the final grid coordinates + grid_h = grid_h * scale_h - shift_h + grid_w = grid_w * scale_w - shift_w + + # Create 2D grid using meshgrid (note: w goes first) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0) + + # # Calculate the starting indices for the central region + # # This is used for debugging/visualization of the central region + # st_h = (grid_size[0] - sample_size) // 2 + # st_w = (grid_size[1] - sample_size) // 2 + # print(grid[:, st_h : st_h + sample_size, st_w : st_w + sample_size]) + + # Reshape grid for positional embedding calculation + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + + # Generate the sinusoidal positional embeddings + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + + # Add zeros for extra tokens (e.g., [CLS] token) if required + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + + return pos_embed + + +# if __name__ == "__main__": +# # This is what you get when you load SD3.5 state dict +# pos_emb = torch.from_numpy(get_scaled_2d_sincos_pos_embed( +# 1536, [384, 384], sample_size=64, base_size=16 +# )).float().unsqueeze(0) + + def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ embed_dim: output dimension for each position @@ -617,7 +689,7 @@ class MMDiTBlock(nn.Module): self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs) self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs) - + self.head_dim = self.x_block.attn.head_dim self.mode = self.x_block.attn_mode self.gradient_checkpointing = False @@ -669,6 +741,9 @@ class MMDiT(nn.Module): Diffusion model with a Transformer backbone. """ + # prepare pos_embed for latent size * 2 + POS_EMBED_MAX_RATIO = 1.5 + def __init__( self, input_size: int = 32, @@ -697,6 +772,8 @@ class MMDiT(nn.Module): x_block_self_attn_layers: Optional[list[int]] = [], qkv_bias: bool = True, pos_emb_random_crop_rate: float = 0.0, + use_scaled_pos_embed: bool = False, + pos_embed_latent_sizes: Optional[list[int]] = None, model_type: str = "sd3m", ): super().__init__() @@ -722,6 +799,8 @@ class MMDiT(nn.Module): self.num_heads = num_heads + self.enable_scaled_pos_embed(use_scaled_pos_embed, pos_embed_latent_sizes) + self.x_embedder = PatchEmbed( input_size, patch_size, @@ -785,6 +864,43 @@ class MMDiT(nn.Module): self.blocks_to_swap = None self.thread_pool: Optional[ThreadPoolExecutor] = None + def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]): + self.use_scaled_pos_embed = use_scaled_pos_embed + + if self.use_scaled_pos_embed: + # # remove pos_embed to free up memory up to 0.4 GB + self.pos_embed = None + + # sort latent sizes in ascending order + latent_sizes = sorted(latent_sizes) + + patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes] + + # calculate value range for each latent area: this is used to determine the pos_emb size from the latent shape + max_areas = [] + for i in range(1, len(patched_sizes)): + prev_area = patched_sizes[i - 1] ** 2 + area = patched_sizes[i] ** 2 + max_areas.append((prev_area + area) // 2) + + # area of the last latent size, if the latent size exceeds this, error will be raised + max_areas.append(int((patched_sizes[-1] * MMDiT.POS_EMBED_MAX_RATIO) ** 2)) + # print("max_areas", max_areas) + + self.resolution_area_to_latent_size = [(area, latent_size) for area, latent_size in zip(max_areas, patched_sizes)] + + self.resolution_pos_embeds = {} + for patched_size in patched_sizes: + grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO) + pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) + self.resolution_pos_embeds[patched_size] = pos_embed + # print(f"pos_embed for {patched_size}x{patched_size} latent size: {pos_embed.shape}") + + else: + self.resolution_area_to_latent_size = None + self.resolution_pos_embeds = None + @property def model_type(self): return self._model_type @@ -884,6 +1000,54 @@ class MMDiT(nn.Module): spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) return spatial_pos_embed + def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: bool = False): + p = self.x_embedder.patch_size + # patched size + h = (h + 1) // p + w = (w + 1) // p + + # select pos_embed size based on area + area = h * w + patched_size = None + for area_, patched_size_ in self.resolution_area_to_latent_size: + if area <= area_: + patched_size = patched_size_ + break + if patched_size is None: + raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.") + + pos_embed_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO) + if h > pos_embed_size or w > pos_embed_size: + # fallback to normal pos_embed + logger.warning( + f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide." + ) + return self.cropped_pos_embed(h, w, device=device, random_crop=random_crop) + + if not random_crop: + top = (pos_embed_size - h) // 2 + left = (pos_embed_size - w) // 2 + else: + top = torch.randint(0, pos_embed_size - h + 1, (1,)).item() + left = torch.randint(0, pos_embed_size - w + 1, (1,)).item() + + pos_embed = self.resolution_pos_embeds[patched_size] + if pos_embed.device != device: + pos_embed = pos_embed.to(device) + # which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device. + self.resolution_pos_embeds[patched_size] = pos_embed # update device + if pos_embed.dtype != dtype: + pos_embed = pos_embed.to(dtype) + self.resolution_pos_embeds[patched_size] = pos_embed # update dtype + + spatial_pos_embed = pos_embed.reshape(1, pos_embed_size, pos_embed_size, pos_embed.shape[-1]) + spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + # print( + # f"patched size: {h}x{w}, pos_embed size: {pos_embed_size}, pos_embed shape: {pos_embed.shape}, top: {top}, left: {left}" + # ) + return spatial_pos_embed + def enable_block_swap(self, num_blocks: int): self.blocks_to_swap = num_blocks @@ -931,7 +1095,16 @@ class MMDiT(nn.Module): ) B, C, H, W = x.shape - x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype) + + # x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype) + if not self.use_scaled_pos_embed: + pos_embed = self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype) + else: + # print(f"Using scaled pos_embed for size {H}x{W}") + pos_embed = self.cropped_scaled_pos_embed(H, W, device=x.device, dtype=x.dtype, random_crop=pos_emb_random_crop) + x = self.x_embedder(x) + pos_embed + del pos_embed + c = self.t_embedder(t, dtype=x.dtype) # (N, D) if y is not None and self.y_embedder is not None: y = self.y_embedder(y) # (N, D) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 86f0c9c0..69878750 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -246,6 +246,12 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M" " / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります", ) + parser.add_argument( + "--enable_scaled_pos_embed", + action="store_true", + help="Scale position embeddings for each resolution during multi-resolution training. Only for SD3.5M" + " / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります", + ) # copy from Diffusers parser.add_argument( diff --git a/library/strategy_base.py b/library/strategy_base.py index e390c5f3..358e42f1 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -518,7 +518,7 @@ class LatentsCachingStrategy: self, npz_path: str, bucket_reso: Tuple[int, int] ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: """ - for SD/SDXL/SD3.0 + for SD/SDXL """ return self._default_load_latents_from_disk(None, npz_path, bucket_reso) diff --git a/library/strategy_flux.py b/library/strategy_flux.py index f662b62e..5e65927f 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -212,7 +212,7 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy): ) def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, True) + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) def load_latents_from_disk( self, npz_path: str, bucket_reso: Tuple[int, int] @@ -226,7 +226,7 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy): vae_dtype = vae.dtype self._default_cache_batch_latents( - encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, True + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True ) if not train_util.HIGH_VRAM: diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py index 413169ec..1d55fe21 100644 --- a/library/strategy_sd3.py +++ b/library/strategy_sd3.py @@ -399,7 +399,12 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy): ) def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): @@ -407,7 +412,9 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy): vae_device = vae.device vae_dtype = vae.dtype - self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + self._default_cache_batch_latents( + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True + ) if not train_util.HIGH_VRAM: train_util.clean_memory_on_device(vae.device) diff --git a/library/train_util.py b/library/train_util.py index d568523c..bd2ff6ef 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2510,6 +2510,9 @@ class DatasetGroup(torch.utils.data.ConcatDataset): for dataset in self.datasets: dataset.verify_bucket_reso_steps(min_steps) + def get_resolutions(self) -> List[Tuple[int, int]]: + return [(dataset.width, dataset.height) for dataset in self.datasets] + def is_latent_cacheable(self) -> bool: return all([dataset.is_latent_cacheable() for dataset in self.datasets]) diff --git a/sd3_train.py b/sd3_train.py index df273690..40f8c7e1 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -361,7 +361,14 @@ def train(args): # ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate) - + + # set resolutions for positional embeddings + if args.enable_scaled_pos_embed: + resolutions = train_dataset_group.get_resolutions() + latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in resolutions] # 8 is stride for latent + logger.info(f"Prepare scaled positional embeddings for resolutions: {resolutions}, sizes: {latent_sizes}") + mmdit.enable_scaled_pos_embed(True, latent_sizes) + if args.gradient_checkpointing: mmdit.enable_gradient_checkpointing() diff --git a/sd3_train_network.py b/sd3_train_network.py index 3d2a7571..9eeac05c 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -26,8 +26,8 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): super().__init__() self.sample_prompts_te_outputs = None - def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) + def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup): + # super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) if args.fp8_base_unet: @@ -53,6 +53,9 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + # enumerate resolutions from dataset for positional embeddings + self.resolutions = train_dataset_group.get_resolutions() + def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models @@ -67,6 +70,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): self.model_type = mmdit.model_type mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate) + # set resolutions for positional embeddings + if args.enable_scaled_pos_embed: + latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent + logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}") + mmdit.enable_scaled_pos_embed(True, latent_sizes) + if args.fp8_base: # check dtype of model if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz: From 9e23368e3d6288e85c6fe34f4d5774bd4d948517 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 31 Oct 2024 19:58:41 +0900 Subject: [PATCH 282/356] Update SD3 training --- README.md | 205 ++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 168 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index ad2791e7..aff78b2c 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ This repository contains training, generation and utility scripts for Stable Diffusion. -## FLUX.1 training (WIP) +## FLUX.1 and SD3 training (WIP) This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. @@ -9,8 +9,15 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +- [FLUX.1 training](#flux1-training) +- [SD3 training](#sd3-training) + ### Recent Updates +Oct 31, 2024: + +- Added support for SD3.5L/M training. See [SD3 training](#sd3-training) for details. + Oct 19, 2024: - Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training. SD1/2 is not tested yet. This is an experimental feature. @@ -139,7 +146,7 @@ Sep 1, 2024: Aug 29, 2024: Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `requirements.txt` is updated. -### Contents +## FLUX.1 training - [FLUX.1 LoRA training](#flux1-lora-training) - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) @@ -586,53 +593,177 @@ python tools/convert_diffusers_to_flux.py --diffusers_path path/to/diffusers_fol ## SD3 training -SD3 training is done with `sd3_train.py`. +SD3.5L/M training is now available. -__Sep 1, 2024__: -- `--num_last_block_to_freeze` is added to `sd3_train.py`. This option is to freeze the last n blocks of the MMDiT. See [#1417](https://github.com/kohya-ss/sd-scripts/pull/1417) for details. Thanks to sdbds! +### SD3 LoRA training -__Jul 27, 2024__: -- Latents and text encoder outputs caching mechanism is refactored significantly. - - Existing cache files for SD3 need to be recreated. Please delete the previous cache files. - - With this change, dataset initialization is significantly faster, especially for large datasets. +The script is `sd3_train_network.py`. See `--help` for options. -- Architecture-dependent parts are extracted from the dataset (`train_util.py`). This is expected to make it easier to add future architectures. +SD3 model, CLIP-L, CLIP-G, and T5XXL models are recommended to be in float/fp16 format. If you specify `--fp8_base`, you can use fp8 models for SD3. The fp8 model is only compatible with `float8_e4m3fn` format. -- Architecture-dependent parts including the cache mechanism for SD1/2/SDXL are also extracted. The basic operation of SD1/2/SDXL training on the sd3 branch has been confirmed, but there may be bugs. Please use the main or dev branch for SD1/2/SDXL training. +Sample command is below. It will work with 16GB VRAM GPUs (SD3.5L). ---- +``` +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 sd3_train_network.py +--pretrained_model_name_or_path path/to/sd3.5_large.safetensors --clip_l sd3/clip_l.safetensors --clip_g sd3/clip_g.safetensors --t5xxl sd3/t5xxl_fp16.safetensors +--cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers +--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 +--network_module networks.lora_sd3 --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 +--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base +--highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml +--output_dir path/to/output/dir --output_name sd3-lora-name +``` +(The command is multi-line for readability. Please combine it into one line.) -`fp16` and `bf16` are available for mixed precision training. We are not sure which is better. +The training can be done with 12GB VRAM GPUs with Adafactor optimizer. Please use settings like below: -`optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are necessary currently. - -`clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them. - -t5xxl works with `fp16` now. - -There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype. - -`text_encoder_batch_size` is added experimentally for caching faster. - -```toml -learning_rate = 1e-6 # seems to depend on the batch size -optimizer_type = "adafactor" -optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] -cache_text_encoder_outputs = true -cache_text_encoder_outputs_to_disk = true -vae_batch_size = 1 -text_encoder_batch_size = 4 -cache_latents = true -cache_latents_to_disk = true +``` +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 ``` -__2024/7/27:__ +`--cpu_offload_checkpointing` and `--split_mode` are not available for SD3 LoRA training. -Latents およびテキストエンコーダ出力のキャッシュの仕組みを大きくリファクタリングしました。SD3 用の既存のキャッシュファイルの再作成が必要になりますが、ご了承ください(以前のキャッシュファイルは削除してください)。これにより、特にデータセットの規模が大きい場合のデータセット初期化が大幅に高速化されます。 +We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. -データセット (`train_util.py`) からアーキテクチャ依存の部分を切り出しました。これにより将来的なアーキテクチャ追加が容易になると期待しています。 +The trained LoRA model can be used with ComfyUI. -SD1/2/SDXL のキャッシュ機構を含むアーキテクチャ依存の部分も切り出しました。sd3 ブランチの SD1/2/SDXL 学習について、基本的な動作は確認していますが、不具合があるかもしれません。SD1/2/SDXL の学習には main または dev ブランチをお使いください。 +#### Key Options for SD3 LoRA training + +Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome. + +- `--network_module` is the module for LoRA training. Specify `networks.lora_sd3` for SD3 LoRA training. +- `--pretrained_model_name_or_path` is the path to the pretrained model (SD3/3.5). If you specify `--fp8_base`, you can use fp8 models for SD3/3.5. The fp8 model is only compatible with `float8_e4m3fn` format. +- `--clip_l` is the path to the CLIP-L model. +- `--clip_g` is the path to the CLIP-G model. +- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (float8_e4m3fn) models for T5XXL. However, it is recommended to use fp16 models for caching. +- `--vae` is the path to the autoencoder model. __This option is not necessary for SD3.__ VAE is included in the standard SD3 model. +- `--disable_mmap_load_safetensors` is to disable memory mapping when loading safetensors. __This option significantly reduces the memory usage when loading models for Windows users.__ +- `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are the dropout rates for the embeddings of CLIP-L, CLIP-G, and T5XXL, described in [SAI research papre](http://arxiv.org/pdf/2403.03206). The default is 0.0. For LoRA training, it is seems to be better to set 0.0. +- `--pos_emb_random_crop_rate` is the rate of random cropping of positional embeddings, described in [SD3.5M model card](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium). The default is 0. It is seems to be better to set 0.0 for LoRA training. +- `--enable_scaled_pos_embed` is to enable the scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M. Details are described below. + +Other options are described below. + +#### Key Features for SD3 LoRA training + +1. CLIP-L, G and T5XXL LoRA Support: + - SD3 LoRA training now supports CLIP-L, CLIP-G and T5XXL LoRA training. + - Remove `--network_train_unet_only` from your command. + - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L and G is also trained at the same time. + - T5XXL output can be cached for CLIP-L and G LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. + - The learning rates for CLIP-L, CLIP-G and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5 5e-6`. The first value is the learning rate for CLIP-L, the second value is for CLIP-G, and the third value is for T5XXL. If you specify only one, the learning rates for CLIP-L, CLIP-G and T5XXL will be the same. If the third value is not specified, the second value is used for T5XXL. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL. + - The trained LoRA can be used with ComfyUI. + + | trained LoRA|option|network_args|cache_text_encoder_outputs (*1)| + |---|---|---|---| + |MMDiT|`--network_train_unet_only`|-|o| + |MMDiT + CLIP-L + CLIP-G|-|-|o (*2)| + |MMDiT + CLIP-L + CLIP-G + T5XXL|-|`train_t5xxl=True`|-| + |CLIP-L + CLIP-G (*3)|`--network_train_text_encoder_only`|-|o (*2)| + |CLIP-L + CLIP-G + T5XXL (*3)|`--network_train_text_encoder_only`|`train_t5xxl=True`|-| + + - *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. + - *2: T5XXL output can be cached for CLIP-L and G LoRA training. + - *3: Not tested yet. + +2. Experimental FP8/FP16 mixed training: + - `--fp8_base_unet` enables training with fp8 for MMDiT and bf16/fp16 for CLIP-L/G/T5XXL. + - When specifying this option, the `--fp8_base` option is automatically enabled. + +3. Split Q/K/V Projection Layers (Experimental): + - Same as FLUX.1. + +4. CLIP-L/G and T5 Attention Mask Application: + - This function is planned to be implemented in the future. + +5. Multi-resolution Training Support: + - Only for SD3.5M. + - Same as FLUX.1 for data preparation. + - If you train with multiple resolutions, specify `--enable_scaled_pos_embed` to enable the scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M. + + +Technical details of multi-resolution training for SD3.5M: + +The values of the positional embeddings must be the same for each resolution. That is, the same value must be in the same position for 512x512, 768x768, and 1024x1024. To achieve this, the positional embeddings for each resolution are calculated in advance and switched according to the resolution of the training data. This feature is enabled by `--enable_scaled_pos_embed`. + +This idea and the code for calculating scaled positional embeddings are contributed by KohakuBlueleaf. Thanks to KohakuBlueleaf! + + +#### Specify rank for each layer in SD3 LoRA + +You can specify the rank for each layer in SD3 by specifying the following network_args. If you specify `0`, LoRA will not be applied to that layer. + +When network_args is not specified, the default value (`network_dim`) is applied, same as before. + +|network_args|target layer| +|---|---| +|context_attn_dim|attn in context_block| +|context_mlp_dim|mlp in context_block| +|context_mod_dim|adaLN_modulation in context_block| +|x_attn_dim|attn in x_block| +|x_mlp_dim|mlp in x_block| +|x_mod_dim|adaLN_modulation in x_block| + +`"verbose=True"` is also available for debugging. It shows the rank of each layer. + +example: +``` +--network_args "context_attn_dim=2" "context_mlp_dim=3" "context_mod_dim=4" "x_attn_dim=5" "x_mlp_dim=6" "x_mod_dim=7" "verbose=True" +``` + +You can apply LoRA to the conditioning layers of SD3 by specifying `emb_dims` in network_args. When specifying, be sure to specify 6 numbers in `[]` as a comma-separated list. + +example: +``` +--network_args "emb_dims=[2,3,4,5,6,7]" +``` + +Each number corresponds to `context_embedder`, `t_embedder`, `x_embedder`, `y_embedder`, `final_layer_adaLN_modulation`, `final_layer_linear`. The above example applies LoRA to all conditioning layers, with rank 2 for `context_embedder`, 3 for `t_embedder`, 4 for `context_embedder`, 5 for `y_embedder`, 6 for `final_layer_adaLN_modulation`, and 7 for `final_layer_linear`. + +If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,4,0,0]` applies LoRA only to `context_embedder` and `y_embedder`. + +#### Specify blocks to train in SD3 LoRA training + +You can specify the blocks to train in SD3 LoRA training by specifying `train_block_indices` in network_args. The indices are 0-based. The default (when omitted) is to train all blocks. The indices are specified as a list of integers or a range of integers, like `0,1,5,8` or `0,1,4-5,7`. + +The number of blocks depends on the model. The valid range is 0-(the number of blocks - 1). `all` is also available to train all blocks, `none` is also available to train no blocks. + +example: +``` +--network_args "train_block_indices=1,2,6-8" +``` + +### Inference for SD3 with LoRA model + +The inference script is also available. The script is `sd3_minimal_inference.py`. See `--help` for options. + +### SD3 fine-tuning + +Documentation is not available yet. Please refer to the FLUX.1 fine-tuning guide for now. The major difference are following: + +- `--clip_g` is also available for SD3 fine-tuning. +- `--timestep_sampling` `--discrete_flow_shift``--model_prediction_type` --guidance_scale` are not necessary for SD3 fine-tuning. +- Use `--vae` instead of `--ae` if necessary. __This option is not necessary for SD3.__ VAE is included in the standard SD3 model. +- `--disable_mmap_load_safetensors` is available. __This option significantly reduces the memory usage when loading models for Windows users.__ +- `--cpu_offload_checkpointing` is not available for SD3 fine-tuning. +- `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are available same as LoRA training. +- `--pos_emb_random_crop_rate` and `--enable_scaled_pos_embed` are available for SD3.5M fine-tuning. +- Training text encoders is available with `--train_text_encoder` option, similar to SDXL training. + - CLIP-L and G can be trained with `--train_text_encoder` option. Training T5XXL needs `--train_t5xxl` option. + - If you use the cached text encoder outputs for T5XXL with training CLIP-L and G, specify `--use_t5xxl_cache_only`. This option enables to use the cached text encoder outputs for T5XXL only. + - The learning rates for CLIP-L, CLIP-G and T5XXL can be specified separately. `--text_encoder_lr1`, `--text_encoder_lr2` and `--text_encoder_lr3` are available. + +### Extract LoRA from SD3 Models + +Not available yet. + +### Convert SD3 LoRA + +Not available yet. + +### Merge LoRA to SD3 checkpoint + +Not available yet. --- From 830df4abcc85ffdfe08b8f97f2c8351c86149af3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 31 Oct 2024 21:39:07 +0900 Subject: [PATCH 283/356] Fix crashing if image is too tall or wide. --- library/sd3_models.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 0eca94e2..15a5b1db 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -868,7 +868,7 @@ class MMDiT(nn.Module): self.use_scaled_pos_embed = use_scaled_pos_embed if self.use_scaled_pos_embed: - # # remove pos_embed to free up memory up to 0.4 GB + # remove pos_embed to free up memory up to 0.4 GB self.pos_embed = None # sort latent sizes in ascending order @@ -977,7 +977,7 @@ class MMDiT(nn.Module): # patched size h = (h + 1) // p w = (w + 1) // p - if self.pos_embed is None: + if self.pos_embed is None: # should not happen return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device) assert self.pos_embed_max_size is not None assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size) @@ -1016,13 +1016,20 @@ class MMDiT(nn.Module): if patched_size is None: raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.") - pos_embed_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO) + pos_embed = self.resolution_pos_embeds[patched_size] + pos_embed_size = round(math.sqrt(pos_embed.shape[1])) if h > pos_embed_size or w > pos_embed_size: - # fallback to normal pos_embed + # # fallback to normal pos_embed + # return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop) + # extend pos_embed size logger.warning( f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide." ) - return self.cropped_pos_embed(h, w, device=device, random_crop=random_crop) + pos_embed_size = max(h, w) + pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) + self.resolution_pos_embeds[patched_size] = pos_embed + logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}") if not random_crop: top = (pos_embed_size - h) // 2 @@ -1031,7 +1038,6 @@ class MMDiT(nn.Module): top = torch.randint(0, pos_embed_size - h + 1, (1,)).item() left = torch.randint(0, pos_embed_size - w + 1, (1,)).item() - pos_embed = self.resolution_pos_embeds[patched_size] if pos_embed.device != device: pos_embed = pos_embed.to(device) # which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device. From 9aa6f52ac3c1866d00675daf73c7560b8b76093f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 1 Nov 2024 21:43:21 +0900 Subject: [PATCH 284/356] Fix memory leak in latent caching. bmp failed to cache --- library/train_util.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index bd2ff6ef..18d3cf6c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1082,6 +1082,10 @@ class BaseDataset(torch.utils.data.Dataset): info.image = info.image.result() # future to image caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop) + # remove image from memory + for info in batch: + info.image = None + # define ThreadPoolExecutor to load images in parallel max_workers = min(os.cpu_count(), len(image_infos)) max_workers = max(1, max_workers // num_processes) # consider multi-gpu @@ -1397,7 +1401,17 @@ class BaseDataset(torch.utils.data.Dataset): ) def get_image_size(self, image_path): - return imagesize.get(image_path) + # return imagesize.get(image_path) + image_size = imagesize.get(image_path) + if image_size[0] <= 0: + # imagesize doesn't work for some images, so use cv2 + img = cv2.imread(image_path) + if img is not None: + image_size = (img.shape[1], img.shape[0]) + else: + logger.warning(f"failed to get image size: {image_path}") + image_size = (0, 0) + return image_size def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False): img = load_image(image_path, alpha_mask) From 82daa98fe865c30a34638acc145d6f4ea8c193db Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 1 Nov 2024 21:43:47 +0900 Subject: [PATCH 285/356] remove duplicate resolution for scaled pos embed --- library/sd3_models.py | 3 ++- sd3_train.py | 1 + sd3_train_network.py | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 15a5b1db..b09a57db 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -871,7 +871,8 @@ class MMDiT(nn.Module): # remove pos_embed to free up memory up to 0.4 GB self.pos_embed = None - # sort latent sizes in ascending order + # remove duplcates and sort latent sizes in ascending order + latent_sizes = list(set(latent_sizes)) latent_sizes = sorted(latent_sizes) patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes] diff --git a/sd3_train.py b/sd3_train.py index 40f8c7e1..f64e2da2 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -366,6 +366,7 @@ def train(args): if args.enable_scaled_pos_embed: resolutions = train_dataset_group.get_resolutions() latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in resolutions] # 8 is stride for latent + latent_sizes = list(set(latent_sizes)) # remove duplicates logger.info(f"Prepare scaled positional embeddings for resolutions: {resolutions}, sizes: {latent_sizes}") mmdit.enable_scaled_pos_embed(True, latent_sizes) diff --git a/sd3_train_network.py b/sd3_train_network.py index 9eeac05c..0739e094 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -73,6 +73,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): # set resolutions for positional embeddings if args.enable_scaled_pos_embed: latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent + latent_sizes = list(set(latent_sizes)) # remove duplicates logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}") mmdit.enable_scaled_pos_embed(True, latent_sizes) From e0db59695fb56e6b7f42132b70e4f828820143ac Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 2 Nov 2024 11:13:04 +0900 Subject: [PATCH 286/356] update multi-res training in SD3.5M --- README.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index aff78b2c..fb087c23 100644 --- a/README.md +++ b/README.md @@ -679,12 +679,16 @@ Other options are described below. 5. Multi-resolution Training Support: - Only for SD3.5M. - Same as FLUX.1 for data preparation. - - If you train with multiple resolutions, specify `--enable_scaled_pos_embed` to enable the scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M. + - If you train with multiple resolutions, you can enable the scaled positional embeddings with `--enable_scaled_pos_embed`. The default is False. __This option is an experimental feature.__ + + Technical details of multi-resolution training for SD3.5M: -The values of the positional embeddings must be the same for each resolution. That is, the same value must be in the same position for 512x512, 768x768, and 1024x1024. To achieve this, the positional embeddings for each resolution are calculated in advance and switched according to the resolution of the training data. This feature is enabled by `--enable_scaled_pos_embed`. +SD3.5M does not use scaled positional embeddings for multi-resolution training, and is trained with a single positional embedding. Therefore, this feature is very experimental. + +Generally, in multi-resolution training, the values of the positional embeddings must be the same for each resolution. That is, the same value must be in the same position for 512x512, 768x768, and 1024x1024. To achieve this, the positional embeddings for each resolution are calculated in advance and switched according to the resolution of the training data. This feature is enabled by `--enable_scaled_pos_embed`. This idea and the code for calculating scaled positional embeddings are contributed by KohakuBlueleaf. Thanks to KohakuBlueleaf! From 5e32ee26a13394fdee77149c4e96b78c58eabc5e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 2 Nov 2024 15:32:16 +0900 Subject: [PATCH 287/356] fix crashing in DDP training closes #1751 --- sd3_train.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index f64e2da2..e03d1708 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -838,11 +838,31 @@ def train(args): accelerator.log({}, step=0) # show model device and dtype - logger.info(f"mmdit device: {mmdit.device}, dtype: {mmdit.dtype}" if mmdit else "mmdit is None") - logger.info(f"clip_l device: {clip_l.device}, dtype: {clip_l.dtype}" if clip_l else "clip_l is None") - logger.info(f"clip_g device: {clip_g.device}, dtype: {clip_g.dtype}" if clip_g else "clip_g is None") - logger.info(f"t5xxl device: {t5xxl.device}, dtype: {t5xxl.dtype}" if t5xxl else "t5xxl is None") - logger.info(f"vae device: {vae.device}, dtype: {vae.dtype}" if vae is not None else "vae is None") + logger.info( + f"mmdit device: {accelerator.unwrap_model(mmdit).device}, dtype: {accelerator.unwrap_model(mmdit).dtype}" + if mmdit + else "mmdit is None" + ) + logger.info( + f"clip_l device: {accelerator.unwrap_model(clip_l).device}, dtype: {accelerator.unwrap_model(clip_l).dtype}" + if clip_l + else "clip_l is None" + ) + logger.info( + f"clip_g device: {accelerator.unwrap_model(clip_g).device}, dtype: {accelerator.unwrap_model(clip_g).dtype}" + if clip_g + else "clip_g is None" + ) + logger.info( + f"t5xxl device: {accelerator.unwrap_model(t5xxl).device}, dtype: {accelerator.unwrap_model(t5xxl).dtype}" + if t5xxl + else "t5xxl is None" + ) + logger.info( + f"vae device: {accelerator.unwrap_model(vae).device}, dtype: {accelerator.unwrap_model(vae).dtype}" + if vae is not None + else "vae is None" + ) loss_recorder = train_util.LossRecorder() epoch = 0 # avoid error when max_train_steps is 0 From 81c0c965a24ce4f0f86dfa980f803d7616ca46d8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 5 Nov 2024 21:22:42 +0900 Subject: [PATCH 288/356] faster block swap --- flux_train.py | 103 +++++++++---------- library/flux_models.py | 138 ++++++++++++++----------- library/utils.py | 222 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 350 insertions(+), 113 deletions(-) diff --git a/flux_train.py b/flux_train.py index 79c44d7b..afddc897 100644 --- a/flux_train.py +++ b/flux_train.py @@ -17,12 +17,14 @@ import math import os from multiprocessing import Value import time -from typing import List +from typing import List, Optional, Tuple, Union import toml from tqdm import tqdm import torch +import torch.nn as nn +from library import utils from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -466,45 +468,28 @@ def train(args): # memory efficient block swapping - def get_block_unit(dbl_blocks, sgl_blocks, index: int): - if index < len(dbl_blocks): - return (dbl_blocks[index],) - else: - index -= len(dbl_blocks) - index *= 2 - return (sgl_blocks[index], sgl_blocks[index + 1]) + def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): + # start_time = time.perf_counter() + # print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA") + utils.swap_weight_devices(block_to_cpu, block_to_cuda) + # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s") + return bidx_to_cpu, bidx_to_cuda # , event - def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, dbl_blocks, sgl_blocks, device): - def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda, dvc): - # print(f"Backward: Move block {bidx_to_cpu} to CPU") - for block in blocks_to_cpu: - block = block.to("cpu", non_blocking=True) - torch.cuda.empty_cache() + block_to_cpu = blocks[block_idx_to_cpu] + block_to_cuda = blocks[block_idx_to_cuda] - # print(f"Backward: Move block {bidx_to_cuda} to CUDA") - for block in blocks_to_cuda: - block = block.to(dvc, non_blocking=True) + futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) - torch.cuda.synchronize() - # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda}") - return bidx_to_cpu, bidx_to_cuda - - blocks_to_cpu = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cpu) - blocks_to_cuda = get_block_unit(dbl_blocks, sgl_blocks, block_idx_to_cuda) - - futures[block_idx_to_cuda] = thread_pool.submit( - move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda, device - ) - - def wait_blocks_move(block_idx, futures): - if block_idx not in futures: + def wait_blocks_move(block_id, futures): + if block_id not in futures: return - # print(f"Backward: Wait for block {block_idx}") + # print(f"Backward: Wait for block {block_id}") # start_time = time.perf_counter() - future = futures.pop(block_idx) - future.result() - # print(f"Backward: Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") - # torch.cuda.synchronize() + future = futures.pop(block_id) + _, bidx_to_cuda = future.result() + assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}" + # print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s") # print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s") if args.fused_backward_pass: @@ -513,11 +498,11 @@ def train(args): library.adafactor_fused.patch_adafactor_fused(optimizer) - blocks_to_swap = args.blocks_to_swap + double_blocks_to_swap = args.blocks_to_swap // 2 + single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2 num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) - num_block_units = num_double_blocks + num_single_blocks // 2 - handled_unit_indices = set() + handled_block_ids = set() n = 1 # only asynchronous purpose, no need to increase this number # n = 2 @@ -530,28 +515,37 @@ def train(args): if parameter.requires_grad: grad_hook = None - if blocks_to_swap: + if double_blocks_to_swap > 0 or single_blocks_to_swap > 0: is_double = param_name.startswith("double_blocks") is_single = param_name.startswith("single_blocks") - if is_double or is_single: + if is_double and double_blocks_to_swap > 0 or is_single and single_blocks_to_swap > 0: block_idx = int(param_name.split(".")[1]) - unit_idx = block_idx if is_double else num_double_blocks + block_idx // 2 - if unit_idx not in handled_unit_indices: + block_id = (is_double, block_idx) # double or single, block index + if block_id not in handled_block_ids: # swap following (already backpropagated) block - handled_unit_indices.add(unit_idx) + handled_block_ids.add(block_id) # if n blocks were already backpropagated - num_blocks_propagated = num_block_units - unit_idx - 1 + if is_double: + num_blocks = num_double_blocks + blocks_to_swap = double_blocks_to_swap + else: + num_blocks = num_single_blocks + blocks_to_swap = single_blocks_to_swap + + # -1 for 0-based index, -1 for current block is not fully backpropagated yet + num_blocks_propagated = num_blocks - block_idx - 2 swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap - waiting = unit_idx > 0 and unit_idx <= blocks_to_swap + waiting = block_idx > 0 and block_idx <= blocks_to_swap + if swapping or waiting: - block_idx_to_cpu = num_block_units - num_blocks_propagated + block_idx_to_cpu = num_blocks - num_blocks_propagated block_idx_to_cuda = blocks_to_swap - num_blocks_propagated - block_idx_to_wait = unit_idx - 1 + block_idx_to_wait = block_idx - 1 # create swap hook def create_swap_grad_hook( - bidx_to_cpu, bidx_to_cuda, bidx_to_wait, uidx: int, swpng: bool, wtng: bool + is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool ): def __grad_hook(tensor: torch.Tensor): if accelerator.sync_gradients and args.max_grad_norm != 0.0: @@ -559,24 +553,25 @@ def train(args): optimizer.step_param(tensor, param_group) tensor.grad = None - # print(f"Backward: {uidx}, {swpng}, {wtng}") + # print( + # f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}" + # ) if swpng: submit_move_blocks( futures, thread_pool, bidx_to_cpu, bidx_to_cuda, - flux.double_blocks, - flux.single_blocks, - accelerator.device, + flux.double_blocks if is_dbl else flux.single_blocks, + (is_dbl, bidx_to_cuda), # wait for this block ) if wtng: - wait_blocks_move(bidx_to_wait, futures) + wait_blocks_move((is_dbl, bidx_to_wait), futures) return __grad_hook grad_hook = create_swap_grad_hook( - block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, unit_idx, swapping, waiting + is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting ) if grad_hook is None: diff --git a/library/flux_models.py b/library/flux_models.py index 0bc1c02b..48dea4fc 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -7,8 +7,9 @@ from dataclasses import dataclass import math import os import time -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union +from library import utils from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -923,7 +924,8 @@ class Flux(nn.Module): self.blocks_to_swap = None self.thread_pool: Optional[ThreadPoolExecutor] = None - self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2 + self.num_double_blocks = len(self.double_blocks) + self.num_single_blocks = len(self.single_blocks) @property def device(self): @@ -963,14 +965,17 @@ class Flux(nn.Module): def enable_block_swap(self, num_blocks: int): self.blocks_to_swap = num_blocks + self.double_blocks_to_swap = num_blocks // 2 + self.single_blocks_to_swap = (num_blocks - self.double_blocks_to_swap) * 2 + print( + f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {self.double_blocks_to_swap}, single blocks: {self.single_blocks_to_swap}." + ) n = 1 # async block swap. 1 is enough - # n = 2 - # n = max(1, os.cpu_count() // 2) self.thread_pool = ThreadPoolExecutor(max_workers=n) def move_to_device_except_swap_blocks(self, device: torch.device): - # assume model is on cpu + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage if self.blocks_to_swap: save_double_blocks = self.double_blocks save_single_blocks = self.single_blocks @@ -983,31 +988,55 @@ class Flux(nn.Module): self.double_blocks = save_double_blocks self.single_blocks = save_single_blocks - def get_block_unit(self, index: int): - if index < len(self.double_blocks): - return (self.double_blocks[index],) - else: - index -= len(self.double_blocks) - index *= 2 - return self.single_blocks[index], self.single_blocks[index + 1] + # def get_block_unit(self, index: int): + # if index < len(self.double_blocks): + # return (self.double_blocks[index],) + # else: + # index -= len(self.double_blocks) + # index *= 2 + # return self.single_blocks[index], self.single_blocks[index + 1] - def get_unit_index(self, is_double: bool, index: int): - if is_double: - return index - else: - return len(self.double_blocks) + index // 2 + # def get_unit_index(self, is_double: bool, index: int): + # if is_double: + # return index + # else: + # return len(self.double_blocks) + index // 2 def prepare_block_swap_before_forward(self): - # make: first n blocks are on cuda, and last n blocks are on cpu + # # make: first n blocks are on cuda, and last n blocks are on cpu + # if self.blocks_to_swap is None or self.blocks_to_swap == 0: + # # raise ValueError("Block swap is not enabled.") + # return + # for i in range(self.num_block_units - self.blocks_to_swap): + # for b in self.get_block_unit(i): + # b.to(self.device) + # for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units): + # for b in self.get_block_unit(i): + # b.to("cpu") + # clean_memory_on_device(self.device) + + # all blocks are on device, but some weights are on cpu + # make first n blocks weights on device, and last n blocks weights on cpu if self.blocks_to_swap is None or self.blocks_to_swap == 0: # raise ValueError("Block swap is not enabled.") return - for i in range(self.num_block_units - self.blocks_to_swap): - for b in self.get_block_unit(i): - b.to(self.device) - for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units): - for b in self.get_block_unit(i): - b.to("cpu") + + for b in self.double_blocks[0 : self.num_double_blocks - self.double_blocks_to_swap]: + b.to(self.device) + utils.weighs_to_device(b, self.device) # make sure weights are on device + for b in self.double_blocks[self.num_double_blocks - self.double_blocks_to_swap :]: + b.to(self.device) # move block to device first + utils.weighs_to_device(b, "cpu") # make sure weights are on cpu + torch.cuda.synchronize() + clean_memory_on_device(self.device) + + for b in self.single_blocks[0 : self.num_single_blocks - self.single_blocks_to_swap]: + b.to(self.device) + utils.weighs_to_device(b, self.device) # make sure weights are on device + for b in self.single_blocks[self.num_single_blocks - self.single_blocks_to_swap :]: + b.to(self.device) # move block to device first + utils.weighs_to_device(b, "cpu") # make sure weights are on cpu + torch.cuda.synchronize() clean_memory_on_device(self.device) def forward( @@ -1044,27 +1073,22 @@ class Flux(nn.Module): for block in self.single_blocks: img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: - futures = {} + # device = self.device - def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda): - def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda): - # print(f"Moving {bidx_to_cpu} to cpu.") - for block in blocks_to_cpu: - block.to("cpu", non_blocking=True) - torch.cuda.empty_cache() - - # print(f"Moving {bidx_to_cuda} to cuda.") - for block in blocks_to_cuda: - block.to(self.device, non_blocking=True) - - torch.cuda.synchronize() + def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): + start_time = time.perf_counter() + # print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.") + utils.swap_weight_devices(block_to_cpu, block_to_cuda) # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") - return block_idx_to_cpu, block_idx_to_cuda - blocks_to_cpu = self.get_block_unit(block_idx_to_cpu) - blocks_to_cuda = self.get_block_unit(block_idx_to_cuda) + # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") + return block_idx_to_cpu, block_idx_to_cuda # , event + + block_to_cpu = blocks[block_idx_to_cpu] + block_to_cuda = blocks[block_idx_to_cuda] # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.") - return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda) + return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) def wait_for_blocks_move(block_idx, ftrs): if block_idx not in ftrs: @@ -1073,37 +1097,35 @@ class Flux(nn.Module): # start_time = time.perf_counter() ftr = ftrs.pop(block_idx) ftr.result() - # torch.cuda.synchronize() - # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") + # print(f"{block_idx} move blocks took {time.perf_counter() - start_time:.2f} seconds") + double_futures = {} for block_idx, block in enumerate(self.double_blocks): # print(f"Double block {block_idx}") - unit_idx = self.get_unit_index(is_double=True, index=block_idx) - wait_for_blocks_move(unit_idx, futures) + wait_for_blocks_move(block_idx, double_futures) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if unit_idx < self.blocks_to_swap: - block_idx_to_cpu = unit_idx - block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx - future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) - futures[block_idx_to_cuda] = future + if block_idx < self.double_blocks_to_swap: + block_idx_to_cpu = block_idx + block_idx_to_cuda = self.num_double_blocks - self.double_blocks_to_swap + block_idx + future = submit_move_blocks(self.double_blocks, block_idx_to_cpu, block_idx_to_cuda) + double_futures[block_idx_to_cuda] = future img = torch.cat((txt, img), 1) + single_futures = {} for block_idx, block in enumerate(self.single_blocks): # print(f"Single block {block_idx}") - unit_idx = self.get_unit_index(is_double=False, index=block_idx) - if block_idx % 2 == 0: - wait_for_blocks_move(unit_idx, futures) + wait_for_blocks_move(block_idx, single_futures) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_idx % 2 == 1 and unit_idx < self.blocks_to_swap: - block_idx_to_cpu = unit_idx - block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx - future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) - futures[block_idx_to_cuda] = future + if block_idx < self.single_blocks_to_swap: + block_idx_to_cpu = block_idx + block_idx_to_cuda = self.num_single_blocks - self.blocks_to_swap + block_idx + future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda) + single_futures[block_idx_to_cuda] = future img = img[:, txt.shape[1] :, ...] diff --git a/library/utils.py b/library/utils.py index ca0f904d..aed51007 100644 --- a/library/utils.py +++ b/library/utils.py @@ -6,6 +6,7 @@ import json import struct import torch +import torch.nn as nn from torchvision import transforms from diffusers import EulerAncestralDiscreteScheduler import diffusers.schedulers.scheduling_euler_ancestral_discrete @@ -93,6 +94,225 @@ def setup_logging(args=None, log_level=None, reset=False): # region PyTorch utils +# def swap_weights(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): +# assert layer_to_cpu.__class__ == layer_to_cuda.__class__ +# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): +# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: +# # print(f"Swapping {layer_to_cpu.__class__.__name__}-{module_to_cpu.__class__.__name__}.") +# # cpu_tensor = module_to_cuda.weight.data +# # cuda_tensor = module_to_cpu.weight.data +# # assert cuda_tensor.device.type == "cuda" +# # temp_cpu_tensor = cuda_tensor.to("cpu", non_blocking=True) +# # torch.cuda.current_stream().synchronize() +# # cuda_tensor.copy_(cpu_tensor, non_blocking=True) +# # torch.cuda.current_stream().synchronize() +# # cpu_tensor.copy_(temp_cpu_tensor, non_blocking=True) +# # module_to_cpu.weight.data, module_to_cuda.weight.data = cpu_tensor, cuda_tensor +# cuda_tensor_view = module_to_cpu.weight.data +# cpu_tensor_view = module_to_cuda.weight.data +# module_to_cpu.weight.data = module_to_cpu.weight.to("cpu", non_blocking=True).detach().clone() +# module_to_cuda.weight.data = cuda_tensor_view +# module_to_cuda.weight.data.copy_(cpu_tensor_view) + + +def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + # cuda to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.record_stream(stream) + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + stream.synchronize() + + # cpu to cuda + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + stream.synchronize() + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def swap_weight_devices_2st(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + stream_to_cpu = torch.cuda.Stream() + stream_to_cuda = torch.cuda.Stream() + + events = [] + with torch.cuda.stream(stream_to_cpu): + # cuda to offload + offloaded_weights = [] + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True)) + event = torch.cuda.Event() + event.record(stream=stream_to_cpu) + events.append(event) + + with torch.cuda.stream(stream_to_cuda): + # cpu to cuda + for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event in zip(weight_swap_jobs, events): + event.synchronize() + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + # offload to cpu + for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), offloaded_weight in zip( + weight_swap_jobs, offloaded_weights + ): + module_to_cpu.weight.data = offloaded_weight + + stream_to_cuda.synchronize() + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def swap_weight_devices_failed(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + stream_to_cpu = torch.cuda.Stream() + stream_to_cuda = torch.cuda.Stream() + + # cuda to offload + events = [] + with torch.cuda.stream(stream_to_cpu): + offloaded_weights = [] + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.record_stream(stream_to_cpu) + offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True)) + + event = torch.cuda.Event() + event.record(stream=stream_to_cpu) + events.append(event) + + # cpu to cuda + with torch.cuda.stream(stream_to_cuda): + for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event, offloaded_weight in zip( + weight_swap_jobs, events, offloaded_weights + ): + event.synchronize() + cuda_data_view.record_stream(stream_to_cuda) + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + module_to_cpu.weight.data = offloaded_weight + + stream_to_cuda.synchronize() + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + # torch.cuda.current_stream().wait_stream(stream_to_cuda) + # for job in weight_swap_jobs: + # job[2].record_stream(torch.cuda.current_stream()) # record the ownership of the tensor + + +def swap_weight_devices_works_2(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + if not (hasattr(module_to_cpu, "offloaded_weight") or hasattr(module_to_cuda, "offloaded_weight")): + # one of the modules must have the tensor to offload + module_to_cpu.offloaded_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu") + module_to_cpu.offloaded_weight.pin_memory() + offloaded_weight = ( + module_to_cpu.offloaded_weight if hasattr(module_to_cpu, "offloaded_weight") else module_to_cuda.offloaded_weight + ) + assert module_to_cpu.weight.device.type == "cuda" and module_to_cuda.weight.device.type == "cpu" + weight_swap_jobs.append( + (module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data, offloaded_weight) + ) + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + # cuda to offload + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs: + cuda_data_view.record_stream(stream) + offloaded_weight.copy_(module_to_cpu.weight.data, non_blocking=True) + + stream.synchronize() + + # cpu to cuda + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + # offload to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs: + module_to_cpu.weight.data = offloaded_weight + offloaded_weight = cpu_data_view + module_to_cpu.offloaded_weight = offloaded_weight + module_to_cuda.offloaded_weight = offloaded_weight + + stream.synchronize() + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def swap_weight_devices_safe_works(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + if not (hasattr(module_to_cpu, "__cached_cpu_weight") or hasattr(module_to_cuda, "__cached_cuda_weight")): + # one of the modules must have the tensor to cache + module_to_cpu.__cached_cpu_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu") + module_to_cpu.__cached_cpu_weight.pin_memory() + + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + for module_to_cpu, module_to_cuda, cuda_tensor_view, cpu_tensor_view in weight_swap_jobs: + module_to_cpu.weight.data = cuda_tensor_view.to("cpu", non_blocking=True) + module_to_cuda.weight.data = cpu_tensor_view.to("cuda", non_blocking=True) + + torch.cuda.current_stream().synchronize() # wait for the copy from cache to cpu to finish + torch.cuda.empty_cache() + + +# def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): +# assert layer_to_cpu.__class__ == layer_to_cuda.__class__ +# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): +# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: +# assert module_to_cuda.weight.device.type == "cpu" and module_to_cpu.weight.device.type == "cuda" +# weight_on_cuda = module_to_cpu.weight +# weight_on_cpu = module_to_cuda.weight +# cuda_to_cpu_data = weight_on_cuda.data.to("cpu", non_blocking=True) +# event = torch.cuda.current_stream().record_event() +# event.synchronize() +# weight_on_cuda.data.copy_(weight_on_cpu.data, non_blocking=True) +# weight_on_cpu.data = cuda_to_cpu_data +# weight_on_cpu.grad, weight_on_cuda.grad = weight_on_cuda.grad, weight_on_cpu.grad + +# module_to_cpu.weight = weight_on_cpu +# module_to_cuda.weight = weight_on_cuda + + +def weighs_to_device(layer: nn.Module, device: torch.device): + for module in layer.modules(): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data = module.weight.data.to(device, non_blocking=True) + def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: """ @@ -313,6 +533,7 @@ class MemoryEfficientSafeOpen: # return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape) raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)") + def load_safetensors( path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32 ) -> dict[str, torch.Tensor]: @@ -336,7 +557,6 @@ def load_safetensors( return state_dict - # endregion # region Image utils From aab943cea3eb8a91041c857771f1642581133608 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 5 Nov 2024 23:27:41 +0900 Subject: [PATCH 289/356] remove unused weight swapping functions from utils.py --- library/utils.py | 185 ----------------------------------------------- 1 file changed, 185 deletions(-) diff --git a/library/utils.py b/library/utils.py index aed51007..07079c6d 100644 --- a/library/utils.py +++ b/library/utils.py @@ -94,26 +94,6 @@ def setup_logging(args=None, log_level=None, reset=False): # region PyTorch utils -# def swap_weights(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): -# assert layer_to_cpu.__class__ == layer_to_cuda.__class__ -# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): -# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: -# # print(f"Swapping {layer_to_cpu.__class__.__name__}-{module_to_cpu.__class__.__name__}.") -# # cpu_tensor = module_to_cuda.weight.data -# # cuda_tensor = module_to_cpu.weight.data -# # assert cuda_tensor.device.type == "cuda" -# # temp_cpu_tensor = cuda_tensor.to("cpu", non_blocking=True) -# # torch.cuda.current_stream().synchronize() -# # cuda_tensor.copy_(cpu_tensor, non_blocking=True) -# # torch.cuda.current_stream().synchronize() -# # cpu_tensor.copy_(temp_cpu_tensor, non_blocking=True) -# # module_to_cpu.weight.data, module_to_cuda.weight.data = cpu_tensor, cuda_tensor -# cuda_tensor_view = module_to_cpu.weight.data -# cpu_tensor_view = module_to_cuda.weight.data -# module_to_cpu.weight.data = module_to_cpu.weight.to("cpu", non_blocking=True).detach().clone() -# module_to_cuda.weight.data = cuda_tensor_view -# module_to_cuda.weight.data.copy_(cpu_tensor_view) - def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): assert layer_to_cpu.__class__ == layer_to_cuda.__class__ @@ -143,171 +123,6 @@ def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): torch.cuda.current_stream().synchronize() # this prevents the illegal loss value -def swap_weight_devices_2st(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): - assert layer_to_cpu.__class__ == layer_to_cuda.__class__ - - weight_swap_jobs = [] - for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): - if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: - weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) - - stream_to_cpu = torch.cuda.Stream() - stream_to_cuda = torch.cuda.Stream() - - events = [] - with torch.cuda.stream(stream_to_cpu): - # cuda to offload - offloaded_weights = [] - for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: - offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True)) - event = torch.cuda.Event() - event.record(stream=stream_to_cpu) - events.append(event) - - with torch.cuda.stream(stream_to_cuda): - # cpu to cuda - for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event in zip(weight_swap_jobs, events): - event.synchronize() - cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) - module_to_cuda.weight.data = cuda_data_view - - # offload to cpu - for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), offloaded_weight in zip( - weight_swap_jobs, offloaded_weights - ): - module_to_cpu.weight.data = offloaded_weight - - stream_to_cuda.synchronize() - - torch.cuda.current_stream().synchronize() # this prevents the illegal loss value - - -def swap_weight_devices_failed(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): - assert layer_to_cpu.__class__ == layer_to_cuda.__class__ - - weight_swap_jobs = [] - for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): - if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: - weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) - - stream_to_cpu = torch.cuda.Stream() - stream_to_cuda = torch.cuda.Stream() - - # cuda to offload - events = [] - with torch.cuda.stream(stream_to_cpu): - offloaded_weights = [] - for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: - cuda_data_view.record_stream(stream_to_cpu) - offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True)) - - event = torch.cuda.Event() - event.record(stream=stream_to_cpu) - events.append(event) - - # cpu to cuda - with torch.cuda.stream(stream_to_cuda): - for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event, offloaded_weight in zip( - weight_swap_jobs, events, offloaded_weights - ): - event.synchronize() - cuda_data_view.record_stream(stream_to_cuda) - cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) - module_to_cuda.weight.data = cuda_data_view - - module_to_cpu.weight.data = offloaded_weight - - stream_to_cuda.synchronize() - - torch.cuda.current_stream().synchronize() # this prevents the illegal loss value - # torch.cuda.current_stream().wait_stream(stream_to_cuda) - # for job in weight_swap_jobs: - # job[2].record_stream(torch.cuda.current_stream()) # record the ownership of the tensor - - -def swap_weight_devices_works_2(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): - assert layer_to_cpu.__class__ == layer_to_cuda.__class__ - - weight_swap_jobs = [] - for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): - if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: - if not (hasattr(module_to_cpu, "offloaded_weight") or hasattr(module_to_cuda, "offloaded_weight")): - # one of the modules must have the tensor to offload - module_to_cpu.offloaded_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu") - module_to_cpu.offloaded_weight.pin_memory() - offloaded_weight = ( - module_to_cpu.offloaded_weight if hasattr(module_to_cpu, "offloaded_weight") else module_to_cuda.offloaded_weight - ) - assert module_to_cpu.weight.device.type == "cuda" and module_to_cuda.weight.device.type == "cpu" - weight_swap_jobs.append( - (module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data, offloaded_weight) - ) - - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): - # cuda to offload - for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs: - cuda_data_view.record_stream(stream) - offloaded_weight.copy_(module_to_cpu.weight.data, non_blocking=True) - - stream.synchronize() - - # cpu to cuda - for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs: - cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) - module_to_cuda.weight.data = cuda_data_view - - # offload to cpu - for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs: - module_to_cpu.weight.data = offloaded_weight - offloaded_weight = cpu_data_view - module_to_cpu.offloaded_weight = offloaded_weight - module_to_cuda.offloaded_weight = offloaded_weight - - stream.synchronize() - - torch.cuda.current_stream().synchronize() # this prevents the illegal loss value - - -def swap_weight_devices_safe_works(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): - assert layer_to_cpu.__class__ == layer_to_cuda.__class__ - - weight_swap_jobs = [] - for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): - if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: - if not (hasattr(module_to_cpu, "__cached_cpu_weight") or hasattr(module_to_cuda, "__cached_cuda_weight")): - # one of the modules must have the tensor to cache - module_to_cpu.__cached_cpu_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu") - module_to_cpu.__cached_cpu_weight.pin_memory() - - weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) - - for module_to_cpu, module_to_cuda, cuda_tensor_view, cpu_tensor_view in weight_swap_jobs: - module_to_cpu.weight.data = cuda_tensor_view.to("cpu", non_blocking=True) - module_to_cuda.weight.data = cpu_tensor_view.to("cuda", non_blocking=True) - - torch.cuda.current_stream().synchronize() # wait for the copy from cache to cpu to finish - torch.cuda.empty_cache() - - -# def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): -# assert layer_to_cpu.__class__ == layer_to_cuda.__class__ -# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): -# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: -# assert module_to_cuda.weight.device.type == "cpu" and module_to_cpu.weight.device.type == "cuda" -# weight_on_cuda = module_to_cpu.weight -# weight_on_cpu = module_to_cuda.weight -# cuda_to_cpu_data = weight_on_cuda.data.to("cpu", non_blocking=True) -# event = torch.cuda.current_stream().record_event() -# event.synchronize() -# weight_on_cuda.data.copy_(weight_on_cpu.data, non_blocking=True) -# weight_on_cpu.data = cuda_to_cpu_data -# weight_on_cpu.grad, weight_on_cuda.grad = weight_on_cuda.grad, weight_on_cpu.grad - -# module_to_cpu.weight = weight_on_cpu -# module_to_cuda.weight = weight_on_cuda - - def weighs_to_device(layer: nn.Module, device: torch.device): for module in layer.modules(): if hasattr(module, "weight") and module.weight is not None: From 43849030cf35a7c854311e0bee9cb8a92b77dd83 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 6 Nov 2024 21:33:28 +0900 Subject: [PATCH 290/356] Fix to work without latent cache #1758 --- sd3_train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index e03d1708..b8a0d04f 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -885,7 +885,9 @@ def train(args): else: with torch.no_grad(): # encode images to latents. images are [-1, 1] - latents = vae.encode(batch["images"]) + latents = vae.encode(batch["images"].to(vae.device, dtype=vae.dtype)).to( + accelerator.device, dtype=weight_dtype + ) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): @@ -927,7 +929,7 @@ def train(args): if t5_out is None: _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] with torch.set_grad_enabled(train_t5xxl): - input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None + input_ids_t5xxl = input_ids_t5xxl.to("cpu") _, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] ) From 40ed54bfc0ca666c45a4a5d4b7a3064612371005 Mon Sep 17 00:00:00 2001 From: Dango233 Date: Thu, 7 Nov 2024 09:53:54 +0000 Subject: [PATCH 291/356] Simplify Timestep weighting * Remove diffusers dependency in ts & sigma calc * support Shift setting * Add uniform distribution * Default to Uniform distribution and shift 1 --- library/sd3_train_utils.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 69878750..bfe752d5 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -253,12 +253,12 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): " / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります", ) - # copy from Diffusers + # Dependencies of Diffusers noise sampler has been removed for clearity. parser.add_argument( "--weighting_scheme", type=str, - default="logit_normal", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], + default="uniform", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"], help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム", ) parser.add_argument( @@ -279,8 +279,13 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): default=1.29, help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. / モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効", ) - - + parser.add_argument( + "--training_shift", + type=float, + default=1.0, + help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。", + ) + def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" if args.v_parameterization: @@ -965,14 +970,20 @@ def get_noisy_model_input_and_timesteps( logit_std=args.logit_std, mode_scale=args.mode_scale, ) - indices = (u * noise_scheduler.config.num_train_timesteps).long() - timesteps = noise_scheduler.timesteps[indices].to(device=device) + t_min = args.min_timestep if args.min_timestep is not None else 0 + t_max = args.max_timestep if args.max_timestep is not None else 1000 + shift = args.training_shift - # Add noise according to flow matching. - sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + # weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details) + u = (u * shift) / (1 + (shift - 1) * u) + + indices = (u * (t_max-t_min) + t_min).long() + timesteps = indices.to(device=device, dtype=dtype) + + # sigmas according to dlowmatching + sigmas = timesteps / 1000 + sigmas = sigmas.view(-1,1,1,1) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents return noisy_model_input, timesteps, sigmas - -# endregion From e54462a4a9cb3d01c5635f8c191d28cbccfba6e0 Mon Sep 17 00:00:00 2001 From: Dango233 Date: Thu, 7 Nov 2024 09:54:12 +0000 Subject: [PATCH 292/356] Fix SD3 trained lora loading and merging --- networks/lora_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py index efe20245..ce6d1a16 100644 --- a/networks/lora_sd3.py +++ b/networks/lora_sd3.py @@ -601,7 +601,7 @@ class LoRANetwork(torch.nn.Module): or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5) ): apply_text_encoder = True - elif key.startswith(LoRANetwork.LORA_PREFIX_MMDIT): + elif key.startswith(LoRANetwork.LORA_PREFIX_SD3): apply_unet = True if apply_text_encoder: From bafd10d558bf318ccd7059c2b4dce2775b5758da Mon Sep 17 00:00:00 2001 From: Dango233 Date: Thu, 7 Nov 2024 18:21:04 +0800 Subject: [PATCH 293/356] Fix typo --- library/sd3_train_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index bfe752d5..afbe34cf 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -980,7 +980,7 @@ def get_noisy_model_input_and_timesteps( indices = (u * (t_max-t_min) + t_min).long() timesteps = indices.to(device=device, dtype=dtype) - # sigmas according to dlowmatching + # sigmas according to flowmatching sigmas = timesteps / 1000 sigmas = sigmas.view(-1,1,1,1) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents From 5e86323f12178605c0b99bc914b4bd970900ce75 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 7 Nov 2024 21:27:12 +0900 Subject: [PATCH 294/356] Update README and clean-up the code for SD3 timesteps --- README.md | 13 ++++++++++++- library/config_util.py | 2 +- library/sd3_models.py | 2 +- library/sd3_train_utils.py | 17 +++++++++-------- sd3_train.py | 8 ++++---- sd3_train_network.py | 7 +++---- 6 files changed, 30 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index fb087c23..dba76a3c 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,13 @@ The command to install PyTorch is as follows: ### Recent Updates +Nov 7, 2024: + +- The distribution of timesteps during SD3/3.5 training has been adjusted. This applies to both fine-tuning and LoRA training. PR [#1768](https://github.com/kohya-ss/sd-scripts/pull/1768) Thanks to Dango233! + - Previously, the side closer to noise was more sampled, but now it is uniform by default. This may improve the problem of difficulty in learning details. + - Specifically, the problem of double shifting has been fixed. The default for `--weighting_scheme` has been changed to `uniform` (the previous default was `logit_normal`). + - A new option `--training_shift` has been added. The default is 1.0, and all timesteps are sampled uniformly. If less than 1.0, the side closer to the image is more sampled, and if more than 1.0, the side closer to noise is more sampled. + Oct 31, 2024: - Added support for SD3.5L/M training. See [SD3 training](#sd3-training) for details. @@ -641,6 +648,7 @@ Here are the arguments. The arguments and sample settings are still experimental - `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are the dropout rates for the embeddings of CLIP-L, CLIP-G, and T5XXL, described in [SAI research papre](http://arxiv.org/pdf/2403.03206). The default is 0.0. For LoRA training, it is seems to be better to set 0.0. - `--pos_emb_random_crop_rate` is the rate of random cropping of positional embeddings, described in [SD3.5M model card](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium). The default is 0. It is seems to be better to set 0.0 for LoRA training. - `--enable_scaled_pos_embed` is to enable the scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M. Details are described below. +- `--training_shift` is the shift value for the training distribution of timesteps. The default is 1.0 (uniform distribution, no shift). If less than 1.0, the side closer to the image is more sampled, and if more than 1.0, the side closer to noise is more sampled. Other options are described below. @@ -681,7 +689,10 @@ Other options are described below. - Same as FLUX.1 for data preparation. - If you train with multiple resolutions, you can enable the scaled positional embeddings with `--enable_scaled_pos_embed`. The default is False. __This option is an experimental feature.__ - +6. Weighting scheme and training shift: + - The weighting scheme is described in the section 3.1 of the [SD3 paper](https://arxiv.org/abs/2403.03206v1). + - The uniform distribution is the default. If you want to change the distribution, see `--help` for options. + - `--training_shift` is the shift value for the training distribution of timesteps. Technical details of multi-resolution training for SD3.5M: diff --git a/library/config_util.py b/library/config_util.py index fc1fbf46..12d0be17 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -526,7 +526,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu secondary_separator: {subset.secondary_separator} enable_wildcard: {subset.enable_wildcard} caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs} caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} caption_prefix: {subset.caption_prefix} caption_suffix: {subset.caption_suffix} diff --git a/library/sd3_models.py b/library/sd3_models.py index b09a57db..89225fe4 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -871,7 +871,7 @@ class MMDiT(nn.Module): # remove pos_embed to free up memory up to 0.4 GB self.pos_embed = None - # remove duplcates and sort latent sizes in ascending order + # remove duplicates and sort latent sizes in ascending order latent_sizes = list(set(latent_sizes)) latent_sizes = sorted(latent_sizes) diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index afbe34cf..38f3c25f 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -253,7 +253,7 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): " / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります", ) - # Dependencies of Diffusers noise sampler has been removed for clearity. + # Dependencies of Diffusers noise sampler has been removed for clarity. parser.add_argument( "--weighting_scheme", type=str, @@ -285,7 +285,8 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): default=1.0, help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。", ) - + + def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" if args.v_parameterization: @@ -956,9 +957,10 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): return weighting -def get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, device, dtype -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +# endregion + + +def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz = latents.shape[0] # Sample a random timestep for each image @@ -977,13 +979,12 @@ def get_noisy_model_input_and_timesteps( # weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details) u = (u * shift) / (1 + (shift - 1) * u) - indices = (u * (t_max-t_min) + t_min).long() + indices = (u * (t_max - t_min) + t_min).long() timesteps = indices.to(device=device, dtype=dtype) # sigmas according to flowmatching sigmas = timesteps / 1000 - sigmas = sigmas.view(-1,1,1,1) + sigmas = sigmas.view(-1, 1, 1, 1) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents return noisy_model_input, timesteps, sigmas - diff --git a/sd3_train.py b/sd3_train.py index b8a0d04f..24ecbfb7 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -811,8 +811,8 @@ def train(args): progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) - noise_scheduler_copy = copy.deepcopy(noise_scheduler) + # noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + # noise_scheduler_copy = copy.deepcopy(noise_scheduler) if accelerator.is_main_process: init_kwargs = {} @@ -940,11 +940,11 @@ def train(args): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) - bsz = latents.shape[0] + # bsz = latents.shape[0] # get noisy model input and timesteps noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + args, latents, noise, accelerator.device, weight_dtype ) # debug: NaN check for all inputs diff --git a/sd3_train_network.py b/sd3_train_network.py index 0739e094..bb02c7ac 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -275,9 +275,8 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): ) def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: - # shift 3.0 is the default value - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) - self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + # this scheduler is not used in training, but used to get num_train_timesteps etc. + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift) return noise_scheduler def encode_images_to_latents(self, args, accelerator, vae, images): @@ -304,7 +303,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): # get noisy model input and timesteps noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps( - args, self.noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + args, latents, noise, accelerator.device, weight_dtype ) # ensure the hidden state will require grad From f264f4091f734b4e4011257b8571ef97315a1343 Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Thu, 7 Nov 2024 21:30:31 +0900 Subject: [PATCH 295/356] Update README.md --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index dba76a3c..9273fc8f 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,8 @@ Nov 7, 2024: - Previously, the side closer to noise was more sampled, but now it is uniform by default. This may improve the problem of difficulty in learning details. - Specifically, the problem of double shifting has been fixed. The default for `--weighting_scheme` has been changed to `uniform` (the previous default was `logit_normal`). - A new option `--training_shift` has been added. The default is 1.0, and all timesteps are sampled uniformly. If less than 1.0, the side closer to the image is more sampled, and if more than 1.0, the side closer to noise is more sampled. + - The effect of a shift in uniform distribution is shown in the figure below. + - ![Figure_1](https://github.com/user-attachments/assets/99a72c67-adfb-4440-81d4-a718985ff350) Oct 31, 2024: @@ -693,7 +695,8 @@ Other options are described below. - The weighting scheme is described in the section 3.1 of the [SD3 paper](https://arxiv.org/abs/2403.03206v1). - The uniform distribution is the default. If you want to change the distribution, see `--help` for options. - `--training_shift` is the shift value for the training distribution of timesteps. - + - The effect of a shift in uniform distribution is shown in the figure below. + - ![Figure_1](https://github.com/user-attachments/assets/99a72c67-adfb-4440-81d4-a718985ff350) Technical details of multi-resolution training for SD3.5M: From 5eb6d209d5b28d43bf611e0934297703eb041d07 Mon Sep 17 00:00:00 2001 From: Dango233 Date: Thu, 7 Nov 2024 20:33:31 +0800 Subject: [PATCH 296/356] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9273fc8f..fe7c506c 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Nov 7, 2024: - The distribution of timesteps during SD3/3.5 training has been adjusted. This applies to both fine-tuning and LoRA training. PR [#1768](https://github.com/kohya-ss/sd-scripts/pull/1768) Thanks to Dango233! - Previously, the side closer to noise was more sampled, but now it is uniform by default. This may improve the problem of difficulty in learning details. - Specifically, the problem of double shifting has been fixed. The default for `--weighting_scheme` has been changed to `uniform` (the previous default was `logit_normal`). - - A new option `--training_shift` has been added. The default is 1.0, and all timesteps are sampled uniformly. If less than 1.0, the side closer to the image is more sampled, and if more than 1.0, the side closer to noise is more sampled. + - A new option `--training_shift` has been added. The default is 1.0, and all timesteps are sampled uniformly. If less than 1.0, the side closer to the image is more sampled (training more on image details), and if more than 1.0, the side closer to noise is more sampled (training more on overall structure). - The effect of a shift in uniform distribution is shown in the figure below. - ![Figure_1](https://github.com/user-attachments/assets/99a72c67-adfb-4440-81d4-a718985ff350) From e5ac09574928ec02fba5fe78267764d26bb7faa6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 7 Nov 2024 21:39:47 +0900 Subject: [PATCH 297/356] add about dev and sd3 branch to README --- README-ja.md | 4 ++++ README.md | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/README-ja.md b/README-ja.md index 4ae6b233..27cc56c3 100644 --- a/README-ja.md +++ b/README-ja.md @@ -3,6 +3,10 @@ Stable Diffusionの学習、画像生成、その他のスクリプトを入れ [README in English](./README.md) ←更新情報はこちらにあります +開発中のバージョンはdevブランチにあります。最新の変更点はdevブランチをご確認ください。 + +FLUX.1およびSD3/SD3.5対応はsd3ブランチで行っています。それらの学習を行う場合はsd3ブランチをご利用ください。 + GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。 以下のスクリプトがあります。 diff --git a/README.md b/README.md index 1e7d49af..51ac07a1 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,11 @@ This repository contains training, generation and utility scripts for Stable Dif [日本語版READMEはこちら](./README-ja.md) +The development version is in the `dev` branch. Please check the dev branch for the latest changes. + +FLUX.1 and SD3/SD3.5 support is done in the `sd3` branch. If you want to train them, please use the sd3 branch. + + For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais! This repository contains the scripts for: From 186aa5b97d43700706bd8e986e2d5ac3f5d4c9b7 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 7 Nov 2024 22:16:05 +0900 Subject: [PATCH 298/356] fix illeagal block is swapped #1764 --- library/flux_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/flux_models.py b/library/flux_models.py index 48dea4fc..4721fa02 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1077,7 +1077,7 @@ class Flux(nn.Module): def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda): def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): - start_time = time.perf_counter() + # start_time = time.perf_counter() # print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.") utils.swap_weight_devices(block_to_cpu, block_to_cuda) # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") @@ -1123,7 +1123,7 @@ class Flux(nn.Module): if block_idx < self.single_blocks_to_swap: block_idx_to_cpu = block_idx - block_idx_to_cuda = self.num_single_blocks - self.blocks_to_swap + block_idx + block_idx_to_cuda = self.num_single_blocks - self.single_blocks_to_swap + block_idx future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda) single_futures[block_idx_to_cuda] = future From b3248a8eefe066e6502b535a19501363ec352974 Mon Sep 17 00:00:00 2001 From: feffy380 <114889020+feffy380@users.noreply.github.com> Date: Thu, 7 Nov 2024 14:31:05 +0100 Subject: [PATCH 299/356] fix: sort order when getting image size from cache file --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 18d3cf6c..8b5cf214 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1887,7 +1887,7 @@ class DreamBoothDataset(BaseDataset): # make image path to npz path mapping npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix)) - npz_paths.sort() + npz_paths.sort(key=lambda item: item.rsplit("_", maxsplit=2)[0]) # sort by name excluding resolution and cache_suffix npz_path_index = 0 size_set_count = 0 From 8fac3c3b088699f607392694beee76bc0036c8d9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 9 Nov 2024 19:56:02 +0900 Subject: [PATCH 300/356] update README --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 87c81001..14328607 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Nov 9, 2024: + +- Fixed an issue where the image size could not be obtained when caching latents was enabled and a specific file name existed, causing the latent size to be incorrect. See PR [#1770](https://github.com/kohya-ss/sd-scripts/pull/1770) for details. Thanks to feffy380! + Nov 7, 2024: - The distribution of timesteps during SD3/3.5 training has been adjusted. This applies to both fine-tuning and LoRA training. PR [#1768](https://github.com/kohya-ss/sd-scripts/pull/1768) Thanks to Dango233! From 26bd4540a6cc7e62100f4901507d8fa0c5a7f78b Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Mon, 11 Nov 2024 09:25:28 +0800 Subject: [PATCH 301/356] init --- library/train_util.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 8b5cf214..7f396d36 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1405,11 +1405,11 @@ class BaseDataset(torch.utils.data.Dataset): image_size = imagesize.get(image_path) if image_size[0] <= 0: # imagesize doesn't work for some images, so use cv2 - img = cv2.imread(image_path) - if img is not None: - image_size = (img.shape[1], img.shape[0]) - else: - logger.warning(f"failed to get image size: {image_path}") + try: + with Image.open(image_path) as img: + image_size = img.size + except Exception as e: + logger.warning(f"failed to get image size: {image_path}, error: {e}") image_size = (0, 0) return image_size From 02bd76e6c719ad85c108a177405846c5c958bd78 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 11 Nov 2024 21:15:36 +0900 Subject: [PATCH 302/356] Refactor block swapping to utilize custom offloading utilities --- flux_train.py | 222 ++++++++--------------------- library/custom_offloading_utils.py | 216 ++++++++++++++++++++++++++++ library/flux_models.py | 115 +++------------ 3 files changed, 293 insertions(+), 260 deletions(-) create mode 100644 library/custom_offloading_utils.py diff --git a/flux_train.py b/flux_train.py index afddc897..02dede45 100644 --- a/flux_train.py +++ b/flux_train.py @@ -295,7 +295,7 @@ def train(args): # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") - flux.enable_block_swap(args.blocks_to_swap) + flux.enable_block_swap(args.blocks_to_swap, accelerator.device) if not cache_latents: # load VAE here if not cached @@ -338,15 +338,15 @@ def train(args): # determine target layer and block index for each parameter block_type = "other" # double, single or other if np[0].startswith("double_blocks"): - block_idx = int(np[0].split(".")[1]) + block_index = int(np[0].split(".")[1]) block_type = "double" elif np[0].startswith("single_blocks"): - block_idx = int(np[0].split(".")[1]) + block_index = int(np[0].split(".")[1]) block_type = "single" else: - block_idx = -1 + block_index = -1 - param_group_key = (block_type, block_idx) + param_group_key = (block_type, block_index) if param_group_key not in param_group: param_group[param_group_key] = [] param_group[param_group_key].append(p) @@ -466,123 +466,21 @@ def train(args): # resumeする train_util.resume_from_local_or_hf_if_specified(accelerator, args) - # memory efficient block swapping - - def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, block_id): - def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): - # start_time = time.perf_counter() - # print(f"Backward: Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to CUDA") - utils.swap_weight_devices(block_to_cpu, block_to_cuda) - # print(f"Backward: Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s") - return bidx_to_cpu, bidx_to_cuda # , event - - block_to_cpu = blocks[block_idx_to_cpu] - block_to_cuda = blocks[block_idx_to_cuda] - - futures[block_id] = thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) - - def wait_blocks_move(block_id, futures): - if block_id not in futures: - return - # print(f"Backward: Wait for block {block_id}") - # start_time = time.perf_counter() - future = futures.pop(block_id) - _, bidx_to_cuda = future.result() - assert block_id[1] == bidx_to_cuda, f"Block index mismatch: {block_id[1]} != {bidx_to_cuda}" - # print(f"Backward: Waited for block {block_id}: {time.perf_counter()-start_time:.2f}s") - # print(f"Backward: Synchronized: {time.perf_counter()-start_time:.2f}s") - if args.fused_backward_pass: # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) - double_blocks_to_swap = args.blocks_to_swap // 2 - single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2 - num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) - num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) - handled_block_ids = set() - - n = 1 # only asynchronous purpose, no need to increase this number - # n = 2 - # n = max(1, os.cpu_count() // 2) - thread_pool = ThreadPoolExecutor(max_workers=n) - futures = {} - for param_group, param_name_group in zip(optimizer.param_groups, param_names): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: - grad_hook = None - if double_blocks_to_swap > 0 or single_blocks_to_swap > 0: - is_double = param_name.startswith("double_blocks") - is_single = param_name.startswith("single_blocks") - if is_double and double_blocks_to_swap > 0 or is_single and single_blocks_to_swap > 0: - block_idx = int(param_name.split(".")[1]) - block_id = (is_double, block_idx) # double or single, block index - if block_id not in handled_block_ids: - # swap following (already backpropagated) block - handled_block_ids.add(block_id) - - # if n blocks were already backpropagated - if is_double: - num_blocks = num_double_blocks - blocks_to_swap = double_blocks_to_swap - else: - num_blocks = num_single_blocks - blocks_to_swap = single_blocks_to_swap - - # -1 for 0-based index, -1 for current block is not fully backpropagated yet - num_blocks_propagated = num_blocks - block_idx - 2 - swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap - waiting = block_idx > 0 and block_idx <= blocks_to_swap - - if swapping or waiting: - block_idx_to_cpu = num_blocks - num_blocks_propagated - block_idx_to_cuda = blocks_to_swap - num_blocks_propagated - block_idx_to_wait = block_idx - 1 - - # create swap hook - def create_swap_grad_hook( - is_dbl, bidx_to_cpu, bidx_to_cuda, bidx_to_wait, swpng: bool, wtng: bool - ): - def __grad_hook(tensor: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - - # print( - # f"Backward: Block {is_dbl}, {bidx_to_cpu}, {bidx_to_cuda}, {bidx_to_wait}, {swpng}, {wtng}" - # ) - if swpng: - submit_move_blocks( - futures, - thread_pool, - bidx_to_cpu, - bidx_to_cuda, - flux.double_blocks if is_dbl else flux.single_blocks, - (is_dbl, bidx_to_cuda), # wait for this block - ) - if wtng: - wait_blocks_move((is_dbl, bidx_to_wait), futures) - - return __grad_hook - - grad_hook = create_swap_grad_hook( - is_double, block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, swapping, waiting - ) - - if grad_hook is None: - - def __grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - - grad_hook = __grad_hook + def grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None parameter.register_post_accumulate_grad_hook(grad_hook) @@ -601,66 +499,66 @@ def train(args): num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} - blocks_to_swap = args.blocks_to_swap - num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) - num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) - num_block_units = num_double_blocks + num_single_blocks // 2 - - n = 1 # only asynchronous purpose, no need to increase this number - # n = max(1, os.cpu_count() // 2) - thread_pool = ThreadPoolExecutor(max_workers=n) - futures = {} - for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: - block_type, block_idx = block_types_and_indices[opt_idx] - def create_optimizer_hook(btype, bidx): - def optimizer_hook(parameter: torch.Tensor): - # print(f"optimizer_hook: {btype}, {bidx}") - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) - i = parameter_optimizer_map[parameter] - optimizer_hooked_count[i] += 1 - if optimizer_hooked_count[i] == num_parameters_per_group[i]: - optimizers[i].step() - optimizers[i].zero_grad(set_to_none=True) + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) - # swap blocks if necessary - if blocks_to_swap and (btype == "double" or (btype == "single" and bidx % 2 == 0)): - unit_idx = bidx if btype == "double" else num_double_blocks + bidx // 2 - num_blocks_propagated = num_block_units - unit_idx - - swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap - waiting = unit_idx > 0 and unit_idx <= blocks_to_swap - - if swapping: - block_idx_to_cpu = num_block_units - num_blocks_propagated - block_idx_to_cuda = blocks_to_swap - num_blocks_propagated - # print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}") - submit_move_blocks( - futures, - thread_pool, - block_idx_to_cpu, - block_idx_to_cuda, - flux.double_blocks, - flux.single_blocks, - accelerator.device, - ) - - if waiting: - block_idx_to_wait = unit_idx - 1 - wait_blocks_move(block_idx_to_wait, futures) - - return optimizer_hook - - parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx)) + parameter.register_post_accumulate_grad_hook(grad_hook) parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 + # add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook + if is_swapping_blocks: + import library.custom_offloading_utils as custom_offloading_utils + + num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) + num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) + double_blocks_to_swap = args.blocks_to_swap // 2 + single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2 + + offloader_double = custom_offloading_utils.TrainOffloader(num_double_blocks, double_blocks_to_swap, accelerator.device) + offloader_single = custom_offloading_utils.TrainOffloader(num_single_blocks, single_blocks_to_swap, accelerator.device) + + param_name_pairs = [] + if not args.blockwise_fused_optimizers: + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + param_name_pairs.extend(zip(param_group["params"], param_name_group)) + else: + # named_parameters is a list of (name, parameter) pairs + param_name_pairs.extend([(p, n) for n, p in flux.named_parameters()]) + + for parameter, param_name in param_name_pairs: + if not parameter.requires_grad: + continue + + is_double = param_name.startswith("double_blocks") + is_single = param_name.startswith("single_blocks") + if not is_double and not is_single: + continue + + block_index = int(param_name.split(".")[1]) + if is_double: + blocks = flux.double_blocks + offloader = offloader_double + else: + blocks = flux.single_blocks + offloader = offloader_single + + grad_hook = offloader.create_grad_hook(blocks, block_index) + if grad_hook is not None: + parameter.register_post_accumulate_grad_hook(grad_hook) + # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py new file mode 100644 index 00000000..33a41300 --- /dev/null +++ b/library/custom_offloading_utils.py @@ -0,0 +1,216 @@ +from concurrent.futures import ThreadPoolExecutor +import time +from typing import Optional +import torch +import torch.nn as nn + +from library.device_utils import clean_memory_on_device + + +def synchronize_device(device: torch.device): + if device.type == "cuda": + torch.cuda.synchronize() + elif device.type == "xpu": + torch.xpu.synchronize() + elif device.type == "mps": + torch.mps.synchronize() + + +def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + # cuda to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.record_stream(stream) + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + stream.synchronize() + + # cpu to cuda + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + stream.synchronize() + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + """ + not tested + """ + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + # device to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + synchronize_device() + + # cpu to device + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + synchronize_device() + + +def weighs_to_device(layer: nn.Module, device: torch.device): + for module in layer.modules(): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data = module.weight.data.to(device, non_blocking=True) + + +class Offloader: + """ + common offloading class + """ + + def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + self.num_blocks = num_blocks + self.blocks_to_swap = blocks_to_swap + self.device = device + self.debug = debug + + self.thread_pool = ThreadPoolExecutor(max_workers=1) + self.futures = {} + self.cuda_available = device.type == "cuda" + + def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module): + if self.cuda_available: + swap_weight_devices(block_to_cpu, block_to_cuda) + else: + swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda) + + def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): + if self.debug: + start_time = time.perf_counter() + print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}") + + self.swap_weight_devices(block_to_cpu, block_to_cuda) + + if self.debug: + print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s") + return bidx_to_cpu, bidx_to_cuda # , event + + block_to_cpu = blocks[block_idx_to_cpu] + block_to_cuda = blocks[block_idx_to_cuda] + + self.futures[block_idx_to_cuda] = self.thread_pool.submit( + move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda + ) + + def _wait_blocks_move(self, block_idx): + if block_idx not in self.futures: + return + + if self.debug: + print(f"Wait for block {block_idx}") + start_time = time.perf_counter() + + future = self.futures.pop(block_idx) + _, bidx_to_cuda = future.result() + + assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}" + + if self.debug: + print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") + + +class TrainOffloader(Offloader): + """ + supports backward offloading + """ + + def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + super().__init__(num_blocks, blocks_to_swap, device, debug) + self.hook_added = set() + + def create_grad_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]: + if block_index in self.hook_added: + return None + self.hook_added.add(block_index) + + # -1 for 0-based index, -1 for current block is not fully backpropagated yet + num_blocks_propagated = self.num_blocks - block_index - 2 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap + waiting = block_index > 0 and block_index <= self.blocks_to_swap + + if not swapping and not waiting: + return None + + # create hook + block_idx_to_cpu = self.num_blocks - num_blocks_propagated + block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated + block_idx_to_wait = block_index - 1 + + if self.debug: + print( + f"Backward: Created grad hook for block {block_index} with {block_idx_to_cpu}, {block_idx_to_cuda}, {block_idx_to_wait}" + ) + if swapping: + + def grad_hook(tensor: torch.Tensor): + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) + + return grad_hook + + else: + + def grad_hook(tensor: torch.Tensor): + self._wait_blocks_move(block_idx_to_wait) + + return grad_hook + + +class ModelOffloader(Offloader): + """ + supports forward offloading + """ + + def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + super().__init__(num_blocks, blocks_to_swap, device, debug) + + def prepare_block_devices_before_forward(self, blocks: list[nn.Module]): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + + for b in blocks[0 : self.num_blocks - self.blocks_to_swap]: + b.to(self.device) + weighs_to_device(b, self.device) # make sure weights are on device + + for b in blocks[self.num_blocks - self.blocks_to_swap :]: + b.to(self.device) # move block to device first + weighs_to_device(b, "cpu") # make sure weights are on cpu + + synchronize_device(self.device) + clean_memory_on_device(self.device) + + def wait_for_block(self, block_idx: int): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self._wait_blocks_move(block_idx) + + def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + if block_idx >= self.blocks_to_swap: + return + block_idx_to_cpu = block_idx + block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) diff --git a/library/flux_models.py b/library/flux_models.py index 4721fa02..e0bee160 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -18,6 +18,7 @@ import torch from einops import rearrange from torch import Tensor, nn from torch.utils.checkpoint import checkpoint +from library import custom_offloading_utils # USE_REENTRANT = True @@ -923,7 +924,8 @@ class Flux(nn.Module): self.cpu_offload_checkpointing = False self.blocks_to_swap = None - self.thread_pool: Optional[ThreadPoolExecutor] = None + self.offloader_double = None + self.offloader_single = None self.num_double_blocks = len(self.double_blocks) self.num_single_blocks = len(self.single_blocks) @@ -963,16 +965,16 @@ class Flux(nn.Module): print("FLUX: Gradient checkpointing disabled.") - def enable_block_swap(self, num_blocks: int): + def enable_block_swap(self, num_blocks: int, device: torch.device): self.blocks_to_swap = num_blocks - self.double_blocks_to_swap = num_blocks // 2 - self.single_blocks_to_swap = (num_blocks - self.double_blocks_to_swap) * 2 - print( - f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {self.double_blocks_to_swap}, single blocks: {self.single_blocks_to_swap}." - ) + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 - n = 1 # async block swap. 1 is enough - self.thread_pool = ThreadPoolExecutor(max_workers=n) + self.offloader_double = custom_offloading_utils.ModelOffloader(self.num_double_blocks, double_blocks_to_swap, device) + self.offloader_single = custom_offloading_utils.ModelOffloader(self.num_single_blocks, single_blocks_to_swap, device) + print( + f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." + ) def move_to_device_except_swap_blocks(self, device: torch.device): # assume model is on cpu. do not move blocks to device to reduce temporary memory usage @@ -988,56 +990,11 @@ class Flux(nn.Module): self.double_blocks = save_double_blocks self.single_blocks = save_single_blocks - # def get_block_unit(self, index: int): - # if index < len(self.double_blocks): - # return (self.double_blocks[index],) - # else: - # index -= len(self.double_blocks) - # index *= 2 - # return self.single_blocks[index], self.single_blocks[index + 1] - - # def get_unit_index(self, is_double: bool, index: int): - # if is_double: - # return index - # else: - # return len(self.double_blocks) + index // 2 - def prepare_block_swap_before_forward(self): - # # make: first n blocks are on cuda, and last n blocks are on cpu - # if self.blocks_to_swap is None or self.blocks_to_swap == 0: - # # raise ValueError("Block swap is not enabled.") - # return - # for i in range(self.num_block_units - self.blocks_to_swap): - # for b in self.get_block_unit(i): - # b.to(self.device) - # for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units): - # for b in self.get_block_unit(i): - # b.to("cpu") - # clean_memory_on_device(self.device) - - # all blocks are on device, but some weights are on cpu - # make first n blocks weights on device, and last n blocks weights on cpu if self.blocks_to_swap is None or self.blocks_to_swap == 0: - # raise ValueError("Block swap is not enabled.") return - - for b in self.double_blocks[0 : self.num_double_blocks - self.double_blocks_to_swap]: - b.to(self.device) - utils.weighs_to_device(b, self.device) # make sure weights are on device - for b in self.double_blocks[self.num_double_blocks - self.double_blocks_to_swap :]: - b.to(self.device) # move block to device first - utils.weighs_to_device(b, "cpu") # make sure weights are on cpu - torch.cuda.synchronize() - clean_memory_on_device(self.device) - - for b in self.single_blocks[0 : self.num_single_blocks - self.single_blocks_to_swap]: - b.to(self.device) - utils.weighs_to_device(b, self.device) # make sure weights are on device - for b in self.single_blocks[self.num_single_blocks - self.single_blocks_to_swap :]: - b.to(self.device) # move block to device first - utils.weighs_to_device(b, "cpu") # make sure weights are on cpu - torch.cuda.synchronize() - clean_memory_on_device(self.device) + self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) def forward( self, @@ -1073,59 +1030,21 @@ class Flux(nn.Module): for block in self.single_blocks: img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) else: - # device = self.device - - def submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda): - def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): - # start_time = time.perf_counter() - # print(f"Moving {bidx_to_cpu} to cpu and {bidx_to_cuda} to cuda.") - utils.swap_weight_devices(block_to_cpu, block_to_cuda) - # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") - - # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") - return block_idx_to_cpu, block_idx_to_cuda # , event - - block_to_cpu = blocks[block_idx_to_cpu] - block_to_cuda = blocks[block_idx_to_cuda] - # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.") - return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) - - def wait_for_blocks_move(block_idx, ftrs): - if block_idx not in ftrs: - return - # print(f"Waiting for move blocks: {block_idx}") - # start_time = time.perf_counter() - ftr = ftrs.pop(block_idx) - ftr.result() - # print(f"{block_idx} move blocks took {time.perf_counter() - start_time:.2f} seconds") - - double_futures = {} for block_idx, block in enumerate(self.double_blocks): - # print(f"Double block {block_idx}") - wait_for_blocks_move(block_idx, double_futures) + self.offloader_double.wait_for_block(block_idx) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_idx < self.double_blocks_to_swap: - block_idx_to_cpu = block_idx - block_idx_to_cuda = self.num_double_blocks - self.double_blocks_to_swap + block_idx - future = submit_move_blocks(self.double_blocks, block_idx_to_cpu, block_idx_to_cuda) - double_futures[block_idx_to_cuda] = future + self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) img = torch.cat((txt, img), 1) - single_futures = {} for block_idx, block in enumerate(self.single_blocks): - # print(f"Single block {block_idx}") - wait_for_blocks_move(block_idx, single_futures) + self.offloader_single.wait_for_block(block_idx) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_idx < self.single_blocks_to_swap: - block_idx_to_cpu = block_idx - block_idx_to_cuda = self.num_single_blocks - self.single_blocks_to_swap + block_idx - future = submit_move_blocks(self.single_blocks, block_idx_to_cpu, block_idx_to_cuda) - single_futures[block_idx_to_cuda] = future + self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) img = img[:, txt.shape[1] :, ...] From 3fe94b058a039b69b6b178bc086e200e40bfa887 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 12 Nov 2024 08:09:07 +0900 Subject: [PATCH 303/356] update comment --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 7f396d36..a5d6fdd2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1404,7 +1404,7 @@ class BaseDataset(torch.utils.data.Dataset): # return imagesize.get(image_path) image_size = imagesize.get(image_path) if image_size[0] <= 0: - # imagesize doesn't work for some images, so use cv2 + # imagesize doesn't work for some images, so use PIL as a fallback try: with Image.open(image_path) as img: image_size = img.size From cde90b8903870b6b28dae274d07ed27978055e3c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 12 Nov 2024 08:49:05 +0900 Subject: [PATCH 304/356] feat: implement block swapping for FLUX.1 LoRA (WIP) --- flux_train.py | 2 +- flux_train_network.py | 33 ++++++++++++++++++++++++ library/custom_offloading_utils.py | 40 +++++++++++++++++++++++++++++- library/flux_models.py | 8 ++++-- train_network.py | 9 ++++++- 5 files changed, 87 insertions(+), 5 deletions(-) diff --git a/flux_train.py b/flux_train.py index 02dede45..346fe8fb 100644 --- a/flux_train.py +++ b/flux_train.py @@ -519,7 +519,7 @@ def train(args): num_parameters_per_group[opt_idx] += 1 # add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook - if is_swapping_blocks: + if False: # is_swapping_blocks: import library.custom_offloading_utils as custom_offloading_utils num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) diff --git a/flux_train_network.py b/flux_train_network.py index 2b71a897..376cc159 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -25,6 +25,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): super().__init__() self.sample_prompts_te_outputs = None self.is_schnell: Optional[bool] = None + self.is_swapping_blocks: bool = False def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) @@ -78,6 +79,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): if args.split_mode: model = self.prepare_split_model(model, weight_dtype, accelerator) + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + model.enable_block_swap(args.blocks_to_swap, accelerator.device) + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) clip_l.eval() @@ -285,6 +292,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) if not args.split_mode: + if self.is_swapping_blocks: + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() flux_train_utils.sample_images( accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs ) @@ -539,6 +548,19 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): text_encoder.to(te_weight_dtype) # fp8 prepare_fp8(text_encoder, weight_dtype) + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + flux: flux_models.Flux = unet + flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + + return flux + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() @@ -550,6 +572,17 @@ def setup_parser() -> argparse.ArgumentParser: help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", ) + + parser.add_argument( + "--blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of blocks to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップするブロックの数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) return parser diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py index 33a41300..70da9390 100644 --- a/library/custom_offloading_utils.py +++ b/library/custom_offloading_utils.py @@ -183,9 +183,47 @@ class ModelOffloader(Offloader): supports forward offloading """ - def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): super().__init__(num_blocks, blocks_to_swap, device, debug) + # register backward hooks + self.remove_handles = [] + for i, block in enumerate(blocks): + hook = self.create_backward_hook(blocks, i) + if hook is not None: + handle = block.register_full_backward_hook(hook) + self.remove_handles.append(handle) + + def __del__(self): + for handle in self.remove_handles: + handle.remove() + + def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]: + # -1 for 0-based index + num_blocks_propagated = self.num_blocks - block_index - 1 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap + waiting = block_index > 0 and block_index <= self.blocks_to_swap + + if not swapping and not waiting: + return None + + # create hook + block_idx_to_cpu = self.num_blocks - num_blocks_propagated + block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated + block_idx_to_wait = block_index - 1 + + def backward_hook(module, grad_input, grad_output): + if self.debug: + print(f"Backward hook for block {block_index}") + + if swapping: + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) + if waiting: + self._wait_blocks_move(block_idx_to_wait) + return None + + return backward_hook + def prepare_block_devices_before_forward(self, blocks: list[nn.Module]): if self.blocks_to_swap is None or self.blocks_to_swap == 0: return diff --git a/library/flux_models.py b/library/flux_models.py index e0bee160..4fa27252 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -970,8 +970,12 @@ class Flux(nn.Module): double_blocks_to_swap = num_blocks // 2 single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 - self.offloader_double = custom_offloading_utils.ModelOffloader(self.num_double_blocks, double_blocks_to_swap, device) - self.offloader_single = custom_offloading_utils.ModelOffloader(self.num_single_blocks, single_blocks_to_swap, device) + self.offloader_double = custom_offloading_utils.ModelOffloader( + self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device #, debug=True + ) + self.offloader_single = custom_offloading_utils.ModelOffloader( + self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device #, debug=True + ) print( f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." ) diff --git a/train_network.py b/train_network.py index b90aa420..d70f14ad 100644 --- a/train_network.py +++ b/train_network.py @@ -18,6 +18,7 @@ from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed +from accelerate import Accelerator from diffusers import DDPMScheduler from library import deepspeed_utils, model_util, strategy_base, strategy_sd @@ -272,6 +273,11 @@ class NetworkTrainer: def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): text_encoder.text_model.embeddings.to(dtype=weight_dtype) + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + return accelerator.prepare(unet) + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): pass @@ -627,7 +633,8 @@ class NetworkTrainer: training_model = ds_model else: if train_unet: - unet = accelerator.prepare(unet) + # default implementation is: unet = accelerator.prepare(unet) + unet = self.prepare_unet_with_accelerator(args, accelerator, unet) # accelerator does some magic here else: unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator if train_text_encoder: From 2cb7a6db02ae001355f4830581b9fc2ffffe01c6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 12 Nov 2024 21:39:13 +0900 Subject: [PATCH 305/356] feat: add block swap for FLUX.1/SD3 LoRA training --- README.md | 212 ++++++---------------------- flux_train.py | 56 +------- flux_train_network.py | 95 +++++++------ library/custom_offloading_utils.py | 75 ++++------ library/flux_models.py | 19 ++- library/flux_train_utils.py | 48 +------ library/sd3_models.py | 71 +++------- library/sd3_train_utils.py | 49 +------ library/train_util.py | 74 +++++++++- sd3_train.py | 180 +++-------------------- sd3_train_network.py | 30 ++++ tools/cache_latents.py | 1 + tools/cache_text_encoder_outputs.py | 1 + train_network.py | 6 +- 14 files changed, 288 insertions(+), 629 deletions(-) diff --git a/README.md b/README.md index 14328607..1e63b583 100644 --- a/README.md +++ b/README.md @@ -14,150 +14,11 @@ The command to install PyTorch is as follows: ### Recent Updates -Nov 9, 2024: +Nov 12, 2024: -- Fixed an issue where the image size could not be obtained when caching latents was enabled and a specific file name existed, causing the latent size to be incorrect. See PR [#1770](https://github.com/kohya-ss/sd-scripts/pull/1770) for details. Thanks to feffy380! - -Nov 7, 2024: - -- The distribution of timesteps during SD3/3.5 training has been adjusted. This applies to both fine-tuning and LoRA training. PR [#1768](https://github.com/kohya-ss/sd-scripts/pull/1768) Thanks to Dango233! - - Previously, the side closer to noise was more sampled, but now it is uniform by default. This may improve the problem of difficulty in learning details. - - Specifically, the problem of double shifting has been fixed. The default for `--weighting_scheme` has been changed to `uniform` (the previous default was `logit_normal`). - - A new option `--training_shift` has been added. The default is 1.0, and all timesteps are sampled uniformly. If less than 1.0, the side closer to the image is more sampled (training more on image details), and if more than 1.0, the side closer to noise is more sampled (training more on overall structure). - - The effect of a shift in uniform distribution is shown in the figure below. - - ![Figure_1](https://github.com/user-attachments/assets/99a72c67-adfb-4440-81d4-a718985ff350) - -Oct 31, 2024: - -- Added support for SD3.5L/M training. See [SD3 training](#sd3-training) for details. - -Oct 19, 2024: - -- Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training. SD1/2 is not tested yet. This is an experimental feature. - - A method to make the output of LoRA closer to the output when LoRA is not applied, with captions that do not contain trigger words. - - Define a Dataset subset for the regularization image (`is_reg = true`) with `.toml`. Add `custom_attributes.diff_output_preservation = true`. - - See [dataset configuration](docs/config_README-en.md) for the regularization dataset. - - Specify "number of training images x number of repeats >= number of regularization images x number of repeats". - - The weights of DOP is specified by `--prior_loss_weight` option (not dataset config). - - The appropriate value is still unknown. For FLUX, according to the comments in the [PR](https://github.com/kohya-ss/sd-scripts/pull/1710), the value may be 1 (thanks to dxqbYD!). For SDXL, a larger value may be needed (10-100 may be good starting points). - - It may be good to adjust the value so that the loss is about half to three-quarters of the loss when DOP is not applied. -``` -[[datasets.subsets]] -image_dir = "path/to/image/dir" -num_repeats = 1 -is_reg = true -custom_attributes.diff_output_preservation = true # Add this -``` - - -Oct 13, 2024: - -- Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large. - -- During multi-GPU training, caching of latents and Text Encoder outputs is now done in multi-GPU. - - Please make sure that `--highvram` and `--vae_batch_size` are specified correctly. If you have enough VRAM, you can increase the batch size to speed up the caching. - - `--text_encoder_batch_size` option is enabled for FLUX.1 LoRA training and fine tuning. This option specifies the batch size for caching Text Encoder outputs (not for training). The default is same as the dataset batch size. If you have enough VRAM, you can increase the batch size to speed up the caching. - - Multi-threading is also implemented for caching of latents. This may speed up the caching process about 5% (depends on the environment). - - `tools/cache_latents.py` and `tools/cache_text_encoder_outputs.py` also have been updated to support multi-GPU caching. -- `--skip_cache_check` option is added to each training script. - - When specified, the consistency check of the cache file `*.npz` contents (e.g., image size and flip for latents, mask for Text Encoder outputs) is skipped. - - Specify this option if you have a large number of cache files and the consistency check takes time. - - Even if this option is specified, the cache will be created if the file does not exist. - - `--skip_latents_validity_check` in SD3/FLUX.1 is deprecated. Please use `--skip_cache_check` instead. - -Oct 12, 2024 (update 1): - -- [Experimental] FLUX.1 fine-tuning and LoRA training now support "FLUX.1 __compact__" models. - - A compact model is a model that retains the FLUX.1 architecture but reduces the number of double/single blocks from the default 19/38. - - The model is automatically determined based on the keys in *.safetensors. - - Specifications for compact model safetensors: - - Please specify the block indices as consecutive numbers. An error will occur if there are missing numbers. For example, if you reduce the double blocks to 15, the maximum key will be `double_blocks.14.*`. The same applies to single blocks. - - LoRA training is unverified. - - The trained model can be used for inference with `flux_minimal_inference.py`. Other inference environments are unverified. - -Oct 12, 2024: - -- Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)! - - In simple tests, SDXL and FLUX.1 LoRA training worked. FLUX.1 fine-tuning did not work, probably due to a PyTorch-related error. Other scripts are unverified. - - Set up multi-GPU training with `accelerate config`. - - Specify `--rdzv_backend=c10d` when launching `accelerate launch`. You can also edit `config.yaml` directly. - ``` - accelerate launch --rdzv_backend=c10d sdxl_train_network.py ... - ``` - - In multi-GPU training, the memory of multiple GPUs is not integrated. In other words, even if you have two 12GB VRAM GPUs, you cannot train the model that requires 24GB VRAM. Training that can be done with 12GB VRAM is executed at (up to) twice the speed. - -Oct 11, 2024: -- ControlNet training for SDXL has been implemented in this branch. Please use `sdxl_train_control_net.py`. - - For details on defining the dataset, see [here](docs/train_lllite_README.md#creating-a-dataset-configuration-file). - - The learning rate for the copy part of the U-Net is specified by `--learning_rate`. The learning rate for the added modules in ControlNet is specified by `--control_net_lr`. The optimal value is still unknown, but try around U-Net `1e-5` and ControlNet `1e-4`. - - If you want to generate sample images, specify the control image as `--cn path/to/control/image`. - - The trained weights are automatically converted and saved in Diffusers format. It should be available in ComfyUI. -- Weighting of prompts (captions) during training in SDXL is now supported (e.g., `(some text)`, `[some text]`, `(some text:1.4)`, etc.). The function is enabled by specifying `--weighted_captions`. - - The default is `False`. It is same as before, and the parentheses are used as normal text. - - If `--weighted_captions` is specified, please use `\` to escape the parentheses in the prompt. For example, `\(some text:1.4\)`. - -Oct 6, 2024: -- In FLUX.1 LoRA training and fine-tuning, the specified weight file (*.safetensors) is automatically determined to be dev or schnell. This allows schnell models to be loaded correctly. Note that LoRA training with schnell models and fine-tuning with schnell models are unverified. -- FLUX.1 LoRA training and fine-tuning can now load weights in Diffusers format in addition to BFL format (a single *.safetensors file). Please specify the parent directory of `transformer` or `diffusion_pytorch_model-00001-of-00003.safetensors` with the full path. However, Diffusers format CLIP/T5XXL is not supported. Saving is supported only in BFL format. - -Sep 26, 2024: -The implementation of block swap during FLUX.1 fine-tuning has been changed to improve speed about 10% (depends on the environment). A new `--blocks_to_swap` option has been added, and `--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. `--double_blocks_to_swap` and `--single_blocks_to_swap` are working as before, but they will be removed in the future. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. - - -Sep 18, 2024 (update 1): -Fixed an issue where train()/eval() was not called properly with the schedule-free optimizer. The schedule-free optimizer can be used in FLUX.1 LoRA training and fine-tuning for now. - -Sep 18, 2024: - -- Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details. - - Details of the schedule-free optimizer can be found in [facebookresearch/schedule_free](https://github.com/facebookresearch/schedule_free). - - `schedulefree` is added to the dependencies. Please update the library if necessary. - - AdamWScheduleFree or SGDScheduleFree can be used. Specify `adamwschedulefree` or `sgdschedulefree` in `--optimizer_type`. - - Wrapper classes are not available for now. - - These can be used not only for FLUX.1 training but also for other training scripts after merging to the dev/main branch. - -Sep 16, 2024: - - Added `train_double_block_indices` and `train_double_block_indices` to the LoRA training script to specify the indices of the blocks to train. See [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) for details. - -Sep 15, 2024: - -Added a script `convert_diffusers_to_flux.py` to convert Diffusers format FLUX.1 models (checkpoints) to BFL format. See `--help` for usage. Only Flux models are supported. AE/CLIP/T5XXL are not supported. - -The implementation is based on 2kpr's code. Thanks to 2kpr! - -Sep 14, 2024: -- You can now specify the rank for each layer in FLUX.1. See [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) for details. -- OFT is now supported with FLUX.1. See [FLUX.1 OFT training](#flux1-oft-training) for details. - -Sep 11, 2024: -Logging to wandb is improved. See PR [#1576](https://github.com/kohya-ss/sd-scripts/pull/1576) for details. Thanks to p1atdev! - -Sep 10, 2024: -In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. The first value is the learning rate for CLIP-L, and the second value is for T5XXL. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. - -Sep 9, 2024: -Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used. - -Sep 5, 2024 (update 1): - -Added `--cpu_offload_checkpointing` option to LoRA training script. Offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`. - -Sep 5, 2024: - -The LoRA merge script now supports CLIP-L and T5XXL LoRA. Please specify `--clip_l` and `--t5xxl`. `--clip_l_save_to` and `--t5xxl_save_to` specify the save destination for CLIP-L and T5XXL. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details. - -Sep 4, 2024: -- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. See [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) for details. -- In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching. -- Fixed an issue where the training CLIP-L LoRA was not used in sample image generation during LoRA training. - -Sep 1, 2024: -- `--timestamp_sampling` has `flux_shift` option. Thanks to sdbds! - - This is the same shift as FLUX.1 dev inference, adjusting the timestep sampling depending on the resolution. `--discrete_flow_shift` is ignored when `flux_shift` is specified. It is not verified which is better, `shift` or `flux_shift`. - -Aug 29, 2024: -Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `requirements.txt` is updated. +- Improved the implementation of block swap and made it available for both FLUX.1 and SD3 LoRA training. See [FLUX.1 LoRA training](#flux1-lora-training) etc. for how to use the new options. Training is possible with about 8-10GB of VRAM. +- During fine-tuning, the memory usage when specifying the same number of blocks has increased slightly, but the training speed when specifying block swap has been significantly improved. +- There may be bugs due to the significant changes. Feedback is welcome. ## FLUX.1 training @@ -190,7 +51,8 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t --pretrained_model_name_or_path flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 ---network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 +--network_module networks.lora_flux --network_dim 4 --network_train_unet_only +--optimizer_type adamw8bit --learning_rate 1e-4 --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name @@ -198,23 +60,39 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t ``` (The command is multi-line for readability. Please combine it into one line.) -The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below: +We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. + +The trained LoRA model can be used with ComfyUI. + +When training LoRA for Text Encoder (without `--network_train_unet_only`), more VRAM is required. Please refer to the settings below to reduce VRAM usage. + +__Options for GPUs with less VRAM:__ + +By specifying `--block_to_swap`, you can save VRAM by swapping some blocks between CPU and GPU. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. + +Specify a number like `--block_to_swap 10`. A larger number will swap more blocks, saving more VRAM, but training will be slower. In FLUX.1, you can swap up to 35 blocks. + +`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--block_to_swap`. + +Adafactor optimizer may reduce the VRAM usage than 8bit AdamW. Please use settings like below: ``` --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 ``` -The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_mode` and `train_blocks=single` options. Please use settings like below: +The training can be done with 16GB VRAM GPUs with the batch size of 1. Please change your dataset configuration. + +The training can be done with 12GB VRAM GPUs with `--block_to_swap 16` with 8bit AdamW. Please use settings like below: ``` ---optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 +--blocks_to_swap 16 ``` -`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`. +For GPUs with less than 10GB of VRAM, it is recommended to use an fp8 checkpoint for T5XXL. You can download `t5xxl_fp8_e4m3fn.safetensors` from [comfyanonymous/flux_text_encoders](https://huggingface.co/comfyanonymous/flux_text_encoders) (please use without `scaled`). -We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. +10GB VRAM GPUs will work with 22 blocks swapped, and 8GB VRAM GPUs will work with 28 blocks swapped. -The trained LoRA model can be used with ComfyUI. +__`--split_mode` is deprecated. This option is still available, but they will be removed in the future. Please use `--blocks_to_swap` instead. If this option is specified and `--blocks_to_swap` is not specified, `--blocks_to_swap 18` is automatically enabled.__ #### Key Options for FLUX.1 LoRA training @@ -239,6 +117,7 @@ There are many unknown points in FLUX.1 training, so some settings can be specif - `additive`: add to noisy input - `sigma_scaled`: apply sigma scaling, same as SD3 - `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3). +- `--blocks_to_swap`. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. The existing `--loss_type` option may be useful for FLUX.1 training. The default is `l2`. @@ -426,9 +305,9 @@ Options are almost the same as LoRA training. The difference is `--full_bf16`, ` `--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency and stochastic rounding. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. -`--blocks_to_swap` is the number of blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. The recommended maximum value is 36. +`--blocks_to_swap` is the number of blocks to swap. The default is None (no swap). The maximum value is 35. -`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. +`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. This option cannot be used with `--blocks_to_swap`. All these options are experimental and may change in the future. @@ -448,13 +327,13 @@ There are two possible ways to use block swap. It is unknown which is better. 2. Swap many blocks to increase the batch size and shorten the training speed per data. - For example, swapping 20 blocks seems to increase the batch size to about 6. In this case, the training speed per data will be relatively faster than 1. + For example, swapping 35 blocks seems to increase the batch size to about 5. In this case, the training speed per data will be relatively faster than 1. #### Training with <24GB VRAM GPUs Swap 28 blocks without cpu offload checkpointing may be working with 12GB VRAM GPUs. Please try different settings according to VRAM size of your GPU. -T5XXL requires about 10GB of VRAM, so 10GB of VRAM will be minimum requirement for FLUX.1 fine-tuning. +T5XXL requires about 10GB of VRAM, so 10GB of VRAM will be minimum requirement for FLUX.1 fine-tuning. #### Key Features for FLUX.1 fine-tuning @@ -465,17 +344,19 @@ T5XXL requires about 10GB of VRAM, so 10GB of VRAM will be minimum requirement f - Since the transfer between CPU and GPU takes time, the training will be slower. - `--blocks_to_swap` specify the number of blocks to swap. - About 640MB of memory can be saved per block. - - Since the memory usage of one double block and two single blocks is almost the same, the transfer of single blocks is done in units of two. For example, consider the case of `--blocks_to_swap 6`. - - Before the forward pass, all double blocks and 26 (=38-12) single blocks are on the GPU. The last 12 single blocks are on the CPU. - - In the forward pass, the 6 double blocks that have finished calculation (the first 6 blocks) are transferred to the CPU, and the 12 single blocks to be calculated (the last 12 blocks) are transferred to the GPU. - - The same is true for the backward pass, but in reverse order. The 12 single blocks that have finished calculation are transferred to the CPU, and the 6 double blocks to be calculated are transferred to the GPU. - - After the backward pass, the blocks are back to their original locations. + - (Update 1: Nov 12, 2024) + - The maximum number of blocks that can be swapped is 35. + - We are exchanging only the data of the weights (weight.data) in reference to the implementation of OneTrainer (thanks to OneTrainer). However, the mechanism of the exchange is a custom implementation. + - Since it takes time to free CUDA memory (torch.cuda.empty_cache()), we reuse the CUDA memory allocated to weight.data as it is and exchange the weights between modules. + - This shortens the time it takes to exchange weights between modules. + - Since the weights must be almost identical to be exchanged, FLUX.1 exchanges the weights between double blocks and single blocks. + - In SD3, all blocks are similar, but some weights are different, so there are weights that always remain on the GPU. 2. Sample Image Generation: - Sample image generation during training is now supported. - The prompts are cached and used for generation if `--cache_latents` is specified. So changing the prompts during training will not affect the generated images. - Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. - - Note: It will be very slow when `--split_mode` is specified. + - Note: It will be very slow when `--blocks_to_swap` is specified. 3. Experimental Memory-Efficient Saving: - `--mem_eff_save` option can further reduce memory consumption during model saving (about 22GB). @@ -621,20 +502,19 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 sd3_tr --pretrained_model_name_or_path path/to/sd3.5_large.safetensors --clip_l sd3/clip_l.safetensors --clip_g sd3/clip_g.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 ---network_module networks.lora_sd3 --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 +--network_module networks.lora_sd3 --network_dim 4 --network_train_unet_only +--optimizer_type adamw8bit --learning_rate 1e-4 --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name sd3-lora-name ``` (The command is multi-line for readability. Please combine it into one line.) -The training can be done with 12GB VRAM GPUs with Adafactor optimizer. Please use settings like below: +Like FLUX.1 training, the `--blocks_to_swap` option for memory reduction is available. The maximum number of blocks that can be swapped is 36 for SD3.5L and 22 for SD3.5M. -``` ---optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 -``` +Adafactor optimizer is also available. -`--cpu_offload_checkpointing` and `--split_mode` are not available for SD3 LoRA training. +`--cpu_offload_checkpointing` option is not available. We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. diff --git a/flux_train.py b/flux_train.py index 346fe8fb..ad2c7722 100644 --- a/flux_train.py +++ b/flux_train.py @@ -78,6 +78,10 @@ def train(args): ) args.gradient_checkpointing = True + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None @@ -518,47 +522,6 @@ def train(args): parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 - # add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook - if False: # is_swapping_blocks: - import library.custom_offloading_utils as custom_offloading_utils - - num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks) - num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks) - double_blocks_to_swap = args.blocks_to_swap // 2 - single_blocks_to_swap = (args.blocks_to_swap - double_blocks_to_swap) * 2 - - offloader_double = custom_offloading_utils.TrainOffloader(num_double_blocks, double_blocks_to_swap, accelerator.device) - offloader_single = custom_offloading_utils.TrainOffloader(num_single_blocks, single_blocks_to_swap, accelerator.device) - - param_name_pairs = [] - if not args.blockwise_fused_optimizers: - for param_group, param_name_group in zip(optimizer.param_groups, param_names): - param_name_pairs.extend(zip(param_group["params"], param_name_group)) - else: - # named_parameters is a list of (name, parameter) pairs - param_name_pairs.extend([(p, n) for n, p in flux.named_parameters()]) - - for parameter, param_name in param_name_pairs: - if not parameter.requires_grad: - continue - - is_double = param_name.startswith("double_blocks") - is_single = param_name.startswith("single_blocks") - if not is_double and not is_single: - continue - - block_index = int(param_name.split(".")[1]) - if is_double: - blocks = flux.double_blocks - offloader = offloader_double - else: - blocks = flux.single_blocks - offloader = offloader_single - - grad_hook = offloader.create_grad_hook(blocks, block_index) - if grad_hook is not None: - parameter.register_post_accumulate_grad_hook(grad_hook) - # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) @@ -827,6 +790,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) add_custom_train_arguments(parser) # TODO remove this from here + train_util.add_dit_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) parser.add_argument( @@ -851,16 +815,6 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", ) - parser.add_argument( - "--blocks_to_swap", - type=int, - default=None, - help="[EXPERIMENTAL] " - "Sets the number of blocks (~640MB) to swap during the forward and backward passes." - "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." - " / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。" - "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", - ) parser.add_argument( "--double_blocks_to_swap", type=int, diff --git a/flux_train_network.py b/flux_train_network.py index 376cc159..9bcd5928 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -52,10 +52,23 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") - assert not args.split_mode or not args.cpu_offload_checkpointing, ( - "split_mode and cpu_offload_checkpointing cannot be used together" - " / split_modeとcpu_offload_checkpointingは同時に使用できません" - ) + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + + # deprecated split_mode option + if args.split_mode: + if args.blocks_to_swap is not None: + logger.warning( + "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored." + " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。" + ) + else: + logger.warning( + "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set." + " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。" + ) + args.blocks_to_swap = 18 # 18 is safe for most cases train_dataset_group.verify_bucket_reso_steps(32) # TODO check this @@ -75,9 +88,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") elif model.dtype == torch.float8_e4m3fn: logger.info("Loaded fp8 FLUX model") + else: + logger.info( + "Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" + ) + model.to(torch.float8_e4m3fn) - if args.split_mode: - model = self.prepare_split_model(model, weight_dtype, accelerator) + # if args.split_mode: + # model = self.prepare_split_model(model, weight_dtype, accelerator) self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 if self.is_swapping_blocks: @@ -108,6 +127,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + """ def prepare_split_model(self, model, weight_dtype, accelerator): from accelerate import init_empty_weights @@ -144,6 +164,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): logger.info("split model prepared") return flux_lower + """ def get_tokenize_strategy(self, args): _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) @@ -291,14 +312,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): text_encoders = text_encoder # for compatibility text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) - if not args.split_mode: - if self.is_swapping_blocks: - accelerator.unwrap_model(flux).prepare_block_swap_before_forward() - flux_train_utils.sample_images( - accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs - ) - return + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs + ) + # return + """ class FluxUpperLowerWrapper(torch.nn.Module): def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device): super().__init__() @@ -325,6 +344,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs ) clean_memory_on_device(accelerator.device) + """ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) @@ -383,20 +403,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): t5_attn_mask = None def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): - if not args.split_mode: - # normal forward - with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - model_pred = unet( - img=img, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) + # if not args.split_mode: + # normal forward + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=img, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + """ else: # split forward to reduce memory usage assert network.train_blocks == "single", "train_blocks must be single for split mode" @@ -430,6 +451,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): vec.requires_grad_(True) pe.requires_grad_(True) model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + """ return model_pred @@ -558,30 +580,23 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): flux: flux_models.Flux = unet flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() return flux def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) parser.add_argument( "--split_mode", action="store_true", - help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" - + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", - ) - - parser.add_argument( - "--blocks_to_swap", - type=int, - default=None, - help="[EXPERIMENTAL] " - "Sets the number of blocks to swap during the forward and backward passes." - "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." - " / 順伝播および逆伝播中にスワップするブロックの数を設定します。" - "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", + help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead." + " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。", ) return parser diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py index 70da9390..84c2b743 100644 --- a/library/custom_offloading_utils.py +++ b/library/custom_offloading_utils.py @@ -16,13 +16,29 @@ def synchronize_device(device: torch.device): torch.mps.synchronize() -def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): +def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): assert layer_to_cpu.__class__ == layer_to_cuda.__class__ weight_swap_jobs = [] - for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): - if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: - weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules + # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + # print(module_to_cpu.__class__, module_to_cuda.__class__) + # if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + # weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()} + for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules(): + if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None: + module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None) + if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + else: + if module_to_cuda.weight.data.device.type != device.type: + # print( + # f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device" + # ) + module_to_cuda.weight.data = module_to_cuda.weight.data.to(device) torch.cuda.current_stream().synchronize() # this prevents the illegal loss value @@ -92,7 +108,7 @@ class Offloader: def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module): if self.cuda_available: - swap_weight_devices(block_to_cpu, block_to_cuda) + swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda) else: swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda) @@ -132,52 +148,6 @@ class Offloader: print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") -class TrainOffloader(Offloader): - """ - supports backward offloading - """ - - def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): - super().__init__(num_blocks, blocks_to_swap, device, debug) - self.hook_added = set() - - def create_grad_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]: - if block_index in self.hook_added: - return None - self.hook_added.add(block_index) - - # -1 for 0-based index, -1 for current block is not fully backpropagated yet - num_blocks_propagated = self.num_blocks - block_index - 2 - swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap - waiting = block_index > 0 and block_index <= self.blocks_to_swap - - if not swapping and not waiting: - return None - - # create hook - block_idx_to_cpu = self.num_blocks - num_blocks_propagated - block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated - block_idx_to_wait = block_index - 1 - - if self.debug: - print( - f"Backward: Created grad hook for block {block_index} with {block_idx_to_cpu}, {block_idx_to_cuda}, {block_idx_to_wait}" - ) - if swapping: - - def grad_hook(tensor: torch.Tensor): - self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) - - return grad_hook - - else: - - def grad_hook(tensor: torch.Tensor): - self._wait_blocks_move(block_idx_to_wait) - - return grad_hook - - class ModelOffloader(Offloader): """ supports forward offloading @@ -228,6 +198,9 @@ class ModelOffloader(Offloader): if self.blocks_to_swap is None or self.blocks_to_swap == 0: return + if self.debug: + print("Prepare block devices before forward") + for b in blocks[0 : self.num_blocks - self.blocks_to_swap]: b.to(self.device) weighs_to_device(b, self.device) # make sure weights are on device diff --git a/library/flux_models.py b/library/flux_models.py index 4fa27252..fa3c7ad2 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -970,11 +970,16 @@ class Flux(nn.Module): double_blocks_to_swap = num_blocks // 2 single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( + f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. " + f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." + ) + self.offloader_double = custom_offloading_utils.ModelOffloader( - self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device #, debug=True + self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True ) self.offloader_single = custom_offloading_utils.ModelOffloader( - self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device #, debug=True + self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True ) print( f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." @@ -1061,10 +1066,11 @@ class Flux(nn.Module): return img +""" class FluxUpper(nn.Module): - """ + "" Transformer model for flow matching on sequences. - """ + "" def __init__(self, params: FluxParams): super().__init__() @@ -1168,9 +1174,9 @@ class FluxUpper(nn.Module): class FluxLower(nn.Module): - """ + "" Transformer model for flow matching on sequences. - """ + "" def __init__(self, params: FluxParams): super().__init__() @@ -1228,3 +1234,4 @@ class FluxLower(nn.Module): img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img +""" diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index fa673a2f..d90644a2 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -257,14 +257,9 @@ def sample_image_inference( wandb_tracker = accelerator.get_tracker("wandb") import wandb + # not to commit images to avoid inconsistency between training and logging steps - wandb_tracker.log( - {f"sample_{i}": wandb.Image( - image, - caption=prompt # positive prompt as a caption - )}, - commit=False - ) + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption def time_shift(mu: float, sigma: float, t: torch.Tensor): @@ -324,7 +319,7 @@ def denoise( ) img = img + (t_prev - t_curr) * pred - + model.prepare_block_swap_before_forward() return img @@ -549,44 +544,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): action="store_true", help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する", ) - parser.add_argument( - "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" - ) - parser.add_argument( - "--cache_text_encoder_outputs_to_disk", - action="store_true", - help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", - ) - parser.add_argument( - "--text_encoder_batch_size", - type=int, - default=None, - help="text encoder batch size (default: None, use dataset's batch size)" - + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", - ) - parser.add_argument( - "--disable_mmap_load_safetensors", - action="store_true", - help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", - ) - # copy from Diffusers - parser.add_argument( - "--weighting_scheme", - type=str, - default="none", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], - ) - parser.add_argument( - "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." - ) - parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") - parser.add_argument( - "--mode_scale", - type=float, - default=1.29, - help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", - ) parser.add_argument( "--guidance_scale", type=float, diff --git a/library/sd3_models.py b/library/sd3_models.py index 89225fe4..8b90205d 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -18,6 +18,7 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from transformers import CLIPTokenizer, T5TokenizerFast +from library import custom_offloading_utils from library.device_utils import clean_memory_on_device from .utils import setup_logging @@ -862,7 +863,8 @@ class MMDiT(nn.Module): # self.initialize_weights() self.blocks_to_swap = None - self.thread_pool: Optional[ThreadPoolExecutor] = None + self.offloader = None + self.num_blocks = len(self.joint_blocks) def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]): self.use_scaled_pos_embed = use_scaled_pos_embed @@ -1055,14 +1057,20 @@ class MMDiT(nn.Module): # ) return spatial_pos_embed - def enable_block_swap(self, num_blocks: int): + def enable_block_swap(self, num_blocks: int, device: torch.device): self.blocks_to_swap = num_blocks - n = 1 # async block swap. 1 is enough - self.thread_pool = ThreadPoolExecutor(max_workers=n) + assert ( + self.blocks_to_swap <= self.num_blocks - 2 + ), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks." + + self.offloader = custom_offloading_utils.ModelOffloader( + self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True + ) + print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.") def move_to_device_except_swap_blocks(self, device: torch.device): - # assume model is on cpu + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage if self.blocks_to_swap: save_blocks = self.joint_blocks self.joint_blocks = None @@ -1073,16 +1081,9 @@ class MMDiT(nn.Module): self.joint_blocks = save_blocks def prepare_block_swap_before_forward(self): - # make: first n blocks are on cuda, and last n blocks are on cpu if self.blocks_to_swap is None or self.blocks_to_swap == 0: - # raise ValueError("Block swap is not enabled.") return - num_blocks = len(self.joint_blocks) - for i in range(num_blocks - self.blocks_to_swap): - self.joint_blocks[i].to(self.device) - for i in range(num_blocks - self.blocks_to_swap, num_blocks): - self.joint_blocks[i].to("cpu") - clean_memory_on_device(self.device) + self.offloader.prepare_block_devices_before_forward(self.joint_blocks) def forward( self, @@ -1122,57 +1123,19 @@ class MMDiT(nn.Module): if self.register_length > 0: context = torch.cat( - ( - einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), - default(context, torch.Tensor([]).type_as(x)), - ), - 1, + (einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), default(context, torch.Tensor([]).type_as(x))), 1 ) if not self.blocks_to_swap: for block in self.joint_blocks: context, x = block(context, x, c) else: - futures = {} - - def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda): - def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): - # print(f"Moving {bidx_to_cpu} to cpu.") - block_to_cpu.to("cpu", non_blocking=True) - torch.cuda.empty_cache() - - # print(f"Moving {bidx_to_cuda} to cuda.") - block_to_cuda.to(self.device, non_blocking=True) - - torch.cuda.synchronize() - # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.") - return block_idx_to_cpu, block_idx_to_cuda - - block_to_cpu = self.joint_blocks[block_idx_to_cpu] - block_to_cuda = self.joint_blocks[block_idx_to_cuda] - # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.") - return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda) - - def wait_for_blocks_move(block_idx, ftrs): - if block_idx not in ftrs: - return - # print(f"Waiting for move blocks: {block_idx}") - # start_time = time.perf_counter() - ftr = ftrs.pop(block_idx) - ftr.result() - # torch.cuda.synchronize() - # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds") - for block_idx, block in enumerate(self.joint_blocks): - wait_for_blocks_move(block_idx, futures) + self.offloader.wait_for_block(block_idx) context, x = block(context, x, c) - if block_idx < self.blocks_to_swap: - block_idx_to_cpu = block_idx - block_idx_to_cuda = len(self.joint_blocks) - self.blocks_to_swap + block_idx - future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda) - futures[block_idx_to_cuda] = future + self.offloader.submit_move_blocks(self.joint_blocks, block_idx) x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify return x[:, :, :H, :W] diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 38f3c25f..c4079884 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -142,27 +142,6 @@ def save_sd3_model_on_epoch_end_or_stepwise( def add_sd3_training_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" - ) - parser.add_argument( - "--cache_text_encoder_outputs_to_disk", - action="store_true", - help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", - ) - parser.add_argument( - "--text_encoder_batch_size", - type=int, - default=None, - help="text encoder batch size (default: None, use dataset's batch size)" - + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", - ) - parser.add_argument( - "--disable_mmap_load_safetensors", - action="store_true", - help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", - ) - parser.add_argument( "--clip_l", type=str, @@ -253,32 +232,8 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser): " / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります", ) - # Dependencies of Diffusers noise sampler has been removed for clarity. - parser.add_argument( - "--weighting_scheme", - type=str, - default="uniform", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"], - help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム", - ) - parser.add_argument( - "--logit_mean", - type=float, - default=0.0, - help="mean to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合の平均", - ) - parser.add_argument( - "--logit_std", - type=float, - default=1.0, - help="std to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合のstd", - ) - parser.add_argument( - "--mode_scale", - type=float, - default=1.29, - help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. / モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効", - ) + # Dependencies of Diffusers noise sampler has been removed for clarity in training + parser.add_argument( "--training_shift", type=float, diff --git a/library/train_util.py b/library/train_util.py index a5d6fdd2..e1dfeecd 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1887,7 +1887,9 @@ class DreamBoothDataset(BaseDataset): # make image path to npz path mapping npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix)) - npz_paths.sort(key=lambda item: item.rsplit("_", maxsplit=2)[0]) # sort by name excluding resolution and cache_suffix + npz_paths.sort( + key=lambda item: item.rsplit("_", maxsplit=2)[0] + ) # sort by name excluding resolution and cache_suffix npz_path_index = 0 size_set_count = 0 @@ -3537,8 +3539,8 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--fused_backward_pass", action="store_true", - help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL" - + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効", + help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL, SD3 and FLUX" + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXL、SD3、FLUXでのみ利用可能", ) parser.add_argument( "--lr_scheduler_timescale", @@ -4027,6 +4029,72 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser): ) +def add_dit_training_arguments(parser: argparse.ArgumentParser): + # Text encoder related arguments + parser.add_argument( + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) + parser.add_argument( + "--text_encoder_batch_size", + type=int, + default=None, + help="text encoder batch size (default: None, use dataset's batch size)" + + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", + ) + + # Model loading optimization + parser.add_argument( + "--disable_mmap_load_safetensors", + action="store_true", + help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", + ) + + # Training arguments. partial copy from Diffusers + parser.add_argument( + "--weighting_scheme", + type=str, + default="uniform", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none", "uniform"], + help="weighting scheme for timestep distribution. Default is uniform, uniform and none are the same behavior" + " / タイムステップ分布の重み付けスキーム、デフォルトはuniform、uniform と none は同じ挙動", + ) + parser.add_argument( + "--logit_mean", + type=float, + default=0.0, + help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均", + ) + parser.add_argument( + "--logit_std", + type=float, + default=1.0, + help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd", + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール", + ) + + # offloading + parser.add_argument( + "--blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of blocks to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップするブロックの数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) + + def get_sanitized_config_or_none(args: argparse.Namespace): # if `--log_config` is enabled, return args for logging. if not, return None. # when `--log_config is enabled, filter out sensitive values from args diff --git a/sd3_train.py b/sd3_train.py index 24ecbfb7..a4fc2eec 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -201,21 +201,6 @@ def train(args): # モデルを読み込む # t5xxl_dtype = weight_dtype - # if args.t5xxl_dtype is not None: - # if args.t5xxl_dtype == "fp16": - # t5xxl_dtype = torch.float16 - # elif args.t5xxl_dtype == "bf16": - # t5xxl_dtype = torch.bfloat16 - # elif args.t5xxl_dtype == "fp32" or args.t5xxl_dtype == "float": - # t5xxl_dtype = torch.float32 - # else: - # raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") - # t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device - # clip_dtype = weight_dtype # if not args.train_text_encoder else None - - # if clip_l is not specified, the checkpoint must contain clip_l, so we load state dict here - # if full_fp16/bf16, model_dtype is casted to fp16/bf16. If not, model_dtype is None (float32). - # by loading with model_dtype, we can reduce memory usage. model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx) if args.clip_l is None: sd3_state_dict = utils.load_safetensors( @@ -384,7 +369,7 @@ def train(args): # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") - mmdit.enable_block_swap(args.blocks_to_swap) + mmdit.enable_block_swap(args.blocks_to_swap, accelerator.device) if not cache_latents: # move to accelerator device @@ -611,108 +596,21 @@ def train(args): # resumeする train_util.resume_from_local_or_hf_if_specified(accelerator, args) - # memory efficient block swapping - - def submit_move_blocks(futures, thread_pool, block_idx_to_cpu, block_idx_to_cuda, blocks, device): - def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda, dvc): - # print(f"Backward: Move block {bidx_to_cpu} to CPU") - block_to_cpu = block_to_cpu.to("cpu", non_blocking=True) - torch.cuda.empty_cache() - - # print(f"Backward: Move block {bidx_to_cuda} to CUDA") - block_to_cuda = block_to_cuda.to(dvc, non_blocking=True) - torch.cuda.synchronize() - # print(f"Backward: Done moving blocks {bidx_to_cpu} and {bidx_to_cuda}") - return bidx_to_cpu, bidx_to_cuda - - block_to_cpu = blocks[block_idx_to_cpu] - block_to_cuda = blocks[block_idx_to_cuda] - - futures[block_idx_to_cuda] = thread_pool.submit( - move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda, device - ) - - def wait_blocks_move(block_idx, futures): - if block_idx not in futures: - return - future = futures.pop(block_idx) - future.result() - if args.fused_backward_pass: # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) - blocks_to_swap = args.blocks_to_swap - num_blocks = len(accelerator.unwrap_model(mmdit).joint_blocks) - handled_block_indices = set() - - n = 1 # only asynchronous purpose, no need to increase this number - # n = 2 - # n = max(1, os.cpu_count() // 2) - thread_pool = ThreadPoolExecutor(max_workers=n) - futures = {} - for param_group, param_name_group in zip(optimizer.param_groups, param_names): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: - grad_hook = None - if blocks_to_swap: - is_block = param_name.startswith("joint_blocks") - if is_block: - block_idx = int(param_name.split(".")[1]) - if block_idx not in handled_block_indices: - # swap following (already backpropagated) block - handled_block_indices.add(block_idx) - - # if n blocks were already backpropagated - num_blocks_propagated = num_blocks - block_idx - 1 - swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap - waiting = block_idx > 0 and block_idx <= blocks_to_swap - if swapping or waiting: - block_idx_to_cpu = num_blocks - num_blocks_propagated - block_idx_to_cuda = blocks_to_swap - num_blocks_propagated - block_idx_to_wait = block_idx - 1 - - # create swap hook - def create_swap_grad_hook( - bidx_to_cpu, bidx_to_cuda, bidx_to_wait, bidx: int, swpng: bool, wtng: bool - ): - def __grad_hook(tensor: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - - if swpng: - submit_move_blocks( - futures, - thread_pool, - bidx_to_cpu, - bidx_to_cuda, - mmdit.joint_blocks, - accelerator.device, - ) - if wtng: - wait_blocks_move(bidx_to_wait, futures) - - return __grad_hook - - grad_hook = create_swap_grad_hook( - block_idx_to_cpu, block_idx_to_cuda, block_idx_to_wait, block_idx, swapping, waiting - ) - - if grad_hook is None: - - def __grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None - - grad_hook = __grad_hook + def grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None parameter.register_post_accumulate_grad_hook(grad_hook) @@ -731,59 +629,22 @@ def train(args): num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} - blocks_to_swap = args.blocks_to_swap - num_blocks = len(accelerator.unwrap_model(mmdit).joint_blocks) - - n = 1 # only asynchronous purpose, no need to increase this number - # n = max(1, os.cpu_count() // 2) - thread_pool = ThreadPoolExecutor(max_workers=n) - futures = {} - for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: - block_type, block_idx = block_types_and_indices[opt_idx] - def create_optimizer_hook(btype, bidx): - def optimizer_hook(parameter: torch.Tensor): - # print(f"optimizer_hook: {btype}, {bidx}") - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) - i = parameter_optimizer_map[parameter] - optimizer_hooked_count[i] += 1 - if optimizer_hooked_count[i] == num_parameters_per_group[i]: - optimizers[i].step() - optimizers[i].zero_grad(set_to_none=True) + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) - # swap blocks if necessary - if blocks_to_swap and btype == "joint": - num_blocks_propagated = num_blocks - bidx - - swapping = num_blocks_propagated > 0 and num_blocks_propagated <= blocks_to_swap - waiting = bidx > 0 and bidx <= blocks_to_swap - - if swapping: - block_idx_to_cpu = num_blocks - num_blocks_propagated - block_idx_to_cuda = blocks_to_swap - num_blocks_propagated - # print(f"Backward: Swap blocks {block_idx_to_cpu} and {block_idx_to_cuda}") - submit_move_blocks( - futures, - thread_pool, - block_idx_to_cpu, - block_idx_to_cuda, - mmdit.joint_blocks, - accelerator.device, - ) - - if waiting: - block_idx_to_wait = bidx - 1 - wait_blocks_move(block_idx_to_wait, futures) - - return optimizer_hook - - parameter.register_post_accumulate_grad_hook(create_optimizer_hook(block_type, block_idx)) + parameter.register_post_accumulate_grad_hook(grad_hook) parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 @@ -1130,6 +991,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) add_custom_train_arguments(parser) + train_util.add_dit_training_arguments(parser) sd3_train_utils.add_sd3_training_arguments(parser) parser.add_argument( @@ -1190,16 +1052,6 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", ) - parser.add_argument( - "--blocks_to_swap", - type=int, - default=None, - help="[EXPERIMENTAL] " - "Sets the number of blocks (~640MB) to swap during the forward and backward passes." - "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." - " / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。" - "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", - ) parser.add_argument( "--num_last_block_to_freeze", type=int, diff --git a/sd3_train_network.py b/sd3_train_network.py index bb02c7ac..1726e325 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -51,6 +51,10 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this # enumerate resolutions from dataset for positional embeddings @@ -83,6 +87,17 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): raise ValueError(f"Unsupported fp8 model dtype: {mmdit.dtype}") elif mmdit.dtype == torch.float8_e4m3fn: logger.info("Loaded fp8 SD3 model") + else: + logger.info( + "Cast SD3 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / SD3モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" + ) + mmdit.to(torch.float8_e4m3fn) + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + mmdit.enable_block_swap(args.blocks_to_swap, accelerator.device) clip_l = sd3_utils.load_clip_l( args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict @@ -432,9 +447,24 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) batch["text_encoder_outputs_list"] = text_encoder_outputs_list + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + mmdit: sd3_models.MMDiT = unet + mmdit = accelerator.prepare(mmdit, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(mmdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(mmdit).prepare_block_swap_before_forward() + + return mmdit + def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) sd3_train_utils.add_sd3_training_arguments(parser) return parser diff --git a/tools/cache_latents.py b/tools/cache_latents.py index e2faa58a..c034f949 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -164,6 +164,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_dataset_arguments(parser, True, True, True) train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) + train_util.add_dit_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 7be9ad78..5888b8e3 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -191,6 +191,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_dataset_arguments(parser, True, True, True) train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) + train_util.add_dit_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") diff --git a/train_network.py b/train_network.py index d70f14ad..bbf381f9 100644 --- a/train_network.py +++ b/train_network.py @@ -601,8 +601,10 @@ class NetworkTrainer: # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory - logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}") - unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above + # logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}") + # unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above + logger.info(f"set U-Net weight dtype to {unet_weight_dtype}") + unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator unet.requires_grad_(False) unet.to(dtype=unet_weight_dtype) From 2bb0f547d72cd0256cafebd46d0f61fbe54012ac Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 14 Nov 2024 19:33:12 +0900 Subject: [PATCH 306/356] update grad hook creation to fix TE lr in sd3 fine tuning --- flux_train.py | 19 ++++++++++++------- library/train_util.py | 1 + sd3_train.py | 15 +++++++++------ 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/flux_train.py b/flux_train.py index ad2c7722..a89e2f13 100644 --- a/flux_train.py +++ b/flux_train.py @@ -80,7 +80,9 @@ def train(args): assert ( args.blocks_to_swap is None or args.blocks_to_swap == 0 - ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) or not args.cpu_offload_checkpointing, ( + "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None @@ -480,13 +482,16 @@ def train(args): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: - def grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None - parameter.register_post_accumulate_grad_hook(grad_hook) + return grad_hook + + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) elif args.blockwise_fused_optimizers: # prepare for additional optimizers and lr schedulers diff --git a/library/train_util.py b/library/train_util.py index e1dfeecd..25cf7640 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5913,6 +5913,7 @@ def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True): names.append("unet") names.append("text_encoder1") names.append("text_encoder2") + names.append("text_encoder3") # SD3 append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names) diff --git a/sd3_train.py b/sd3_train.py index a4fc2eec..96ec951b 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -606,13 +606,16 @@ def train(args): for parameter, param_name in zip(param_group["params"], param_name_group): if parameter.requires_grad: - def grad_hook(tensor: torch.Tensor, param_group=param_group): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(tensor, args.max_grad_norm) - optimizer.step_param(tensor, param_group) - tensor.grad = None + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None - parameter.register_post_accumulate_grad_hook(grad_hook) + return grad_hook + + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) elif args.blockwise_fused_optimizers: # prepare for additional optimizers and lr schedulers From 5c5b544b91ac434c12a372cbf1dc123a367ec878 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 14 Nov 2024 19:35:43 +0900 Subject: [PATCH 307/356] refactor: remove unused prepare_split_model method from FluxNetworkTrainer --- flux_train_network.py | 39 --------------------------------------- 1 file changed, 39 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 9bcd5928..704c4d32 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -127,45 +127,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model - """ - def prepare_split_model(self, model, weight_dtype, accelerator): - from accelerate import init_empty_weights - - logger.info("prepare split model") - with init_empty_weights(): - flux_upper = flux_models.FluxUpper(model.params) - flux_lower = flux_models.FluxLower(model.params) - sd = model.state_dict() - - # lower (trainable) - logger.info("load state dict for lower") - flux_lower.load_state_dict(sd, strict=False, assign=True) - flux_lower.to(dtype=weight_dtype) - - # upper (frozen) - logger.info("load state dict for upper") - flux_upper.load_state_dict(sd, strict=False, assign=True) - - logger.info("prepare upper model") - target_dtype = torch.float8_e4m3fn if args.fp8_base else weight_dtype - flux_upper.to(accelerator.device, dtype=target_dtype) - flux_upper.eval() - - if args.fp8_base: - # this is required to run on fp8 - flux_upper = accelerator.prepare(flux_upper) - - flux_upper.to("cpu") - - self.flux_upper = flux_upper - del model # we don't need model anymore - clean_memory_on_device(accelerator.device) - - logger.info("split model prepared") - - return flux_lower - """ - def get_tokenize_strategy(self, args): _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) From fd2d879ac883b8bdf1e03b6ca545c33200dbdff2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 14 Nov 2024 19:43:08 +0900 Subject: [PATCH 308/356] docs: update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1e63b583..81a3199b 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ The command to install PyTorch is as follows: ### Recent Updates -Nov 12, 2024: +Nov 14, 2024: - Improved the implementation of block swap and made it available for both FLUX.1 and SD3 LoRA training. See [FLUX.1 LoRA training](#flux1-lora-training) etc. for how to use the new options. Training is possible with about 8-10GB of VRAM. - During fine-tuning, the memory usage when specifying the same number of blocks has increased slightly, but the training speed when specifying block swap has been significantly improved. From ccfaa001e74f80798e528b4b3ea6ef811017c07b Mon Sep 17 00:00:00 2001 From: minux302 Date: Fri, 15 Nov 2024 20:21:28 +0900 Subject: [PATCH 309/356] add flux controlnet base module --- flux_train_control_net.py | 573 ++++++++++++++++++++++++++++++++++++++ flux_train_network.py | 5 +- library/flux_models.py | 257 ++++++++++++++++- library/flux_utils.py | 8 + 4 files changed, 841 insertions(+), 2 deletions(-) create mode 100644 flux_train_control_net.py diff --git a/flux_train_control_net.py b/flux_train_control_net.py new file mode 100644 index 00000000..704c4d32 --- /dev/null +++ b/flux_train_control_net.py @@ -0,0 +1,573 @@ +import argparse +import copy +import math +import random +from typing import Any, Optional + +import torch +from accelerate import Accelerator +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util +import train_network +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class FluxNetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + self.is_schnell: Optional[bool] = None + self.is_swapping_blocks: bool = False + + def assert_extra_args(self, args, train_dataset_group): + super().assert_extra_args(args, train_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # prepare CLIP-L/T5XXL training flags + self.train_clip_l = not args.network_train_unet_only + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + + if args.max_token_length is not None: + logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + + # deprecated split_mode option + if args.split_mode: + if args.blocks_to_swap is not None: + logger.warning( + "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored." + " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。" + ) + else: + logger.warning( + "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set." + " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。" + ) + args.blocks_to_swap = 18 # 18 is safe for most cases + + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + + def load_target_model(self, args, weight_dtype, accelerator): + # currently offload to cpu for some models + + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future + self.is_schnell, model = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + ) + if args.fp8_base: + # check dtype of model + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") + elif model.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 FLUX model") + else: + logger.info( + "Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" + ) + model.to(torch.float8_e4m3fn) + + # if args.split_mode: + # model = self.prepare_split_model(model, weight_dtype, accelerator) + + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + model.enable_block_swap(args.blocks_to_swap, accelerator.device) + + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + clip_l.eval() + + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + + # loading t5xxl to cpu takes a long time, so we should load to gpu in future + t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + + def get_tokenize_strategy(self, args): + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") + return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy): + return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + if args.cache_text_encoder_outputs: + if self.train_clip_l and not self.train_t5xxl: + return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached + else: + return None # no text encoders are needed for encoding because both are cached + else: + return text_encoders # both CLIP-L and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_clip_l, self.train_t5xxl] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_clip_l or self.train_t5xxl, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device) + + if text_encoders[1].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[1].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # move back to cpu + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[1].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device) + + # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # # get size embeddings + # orig_size = batch["original_sizes_hw"] + # crop_size = batch["crop_top_lefts"] + # target_size = batch["target_sizes_hw"] + # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # # concat embeddings + # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + # return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): + text_encoders = text_encoder # for compatibility + text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs + ) + # return + + """ + class FluxUpperLowerWrapper(torch.nn.Module): + def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device): + super().__init__() + self.flux_upper = flux_upper + self.flux_lower = flux_lower + self.target_device = device + + def prepare_block_swap_before_forward(self): + pass + + def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): + self.flux_lower.to("cpu") + clean_memory_on_device(self.target_device) + self.flux_upper.to(self.target_device) + img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask) + self.flux_upper.to("cpu") + clean_memory_on_device(self.target_device) + self.flux_lower.to(self.target_device) + return self.flux_lower(img, txt, vec, pe, txt_attention_mask) + + wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) + clean_memory_on_device(accelerator.device) + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs + ) + clean_memory_on_device(accelerator.device) + """ + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images) + + def shift_scale_latents(self, args, latents): + return latents + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: flux_models.Flux, + network, + weight_dtype, + train_unet, + ): + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance + # ensure guidance_scale in args is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t is not None and t.dtype.is_floating_point: + t.requires_grad_(True) + img_ids.requires_grad_(True) + guidance_vec.requires_grad_(True) + + # Predict the noise residual + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): + # if not args.split_mode: + # normal forward + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=img, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + """ + else: + # split forward to reduce memory usage + assert network.train_blocks == "single", "train_blocks must be single for split mode" + with accelerator.autocast(): + # move flux lower to cpu, and then move flux upper to gpu + unet.to("cpu") + clean_memory_on_device(accelerator.device) + self.flux_upper.to(accelerator.device) + + # upper model does not require grad + with torch.no_grad(): + intermediate_img, intermediate_txt, vec, pe = self.flux_upper( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # move flux upper back to cpu, and then move flux lower to gpu + self.flux_upper.to("cpu") + clean_memory_on_device(accelerator.device) + unet.to(accelerator.device) + + # lower model requires grad + intermediate_img.requires_grad_(True) + intermediate_txt.requires_grad_(True) + vec.requires_grad_(True) + pe.requires_grad_(True) + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + """ + + return model_pred + + model_pred = call_dit( + img=packed_noisy_model_input, + img_ids=img_ids, + t5_out=t5_out, + txt_ids=txt_ids, + l_pooled=l_pooled, + timesteps=timesteps, + guidance_vec=guidance_vec, + t5_attn_mask=t5_attn_mask, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(): + model_pred_prior = call_dit( + img=packed_noisy_model_input[diff_output_pr_indices], + img_ids=img_ids[diff_output_pr_indices], + t5_out=t5_out[diff_output_pr_indices], + txt_ids=txt_ids[diff_output_pr_indices], + l_pooled=l_pooled[diff_output_pr_indices], + timesteps=timesteps[diff_output_pr_indices], + guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None, + t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + + model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width) + model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( + args, + model_pred_prior, + noisy_model_input[diff_output_pr_indices], + sigmas[diff_output_pr_indices] if sigmas is not None else None, + ) + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + + return model_pred, target, timesteps, None, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") + + def update_metadata(self, metadata, args): + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + metadata["ss_guidance_scale"] = args.guidance_scale + metadata["ss_timestep_sampling"] = args.timestep_sampling + metadata["ss_sigmoid_scale"] = args.sigmoid_scale + metadata["ss_model_prediction_type"] = args.model_prediction_type + metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + if index == 0: # CLIP-L + return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) + else: # T5XXL + text_encoder.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + if index == 0: # CLIP-L + logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + else: # T5XXL + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"T5XXL already prepared for fp8") + else: + logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + prepare_fp8(text_encoder, weight_dtype) + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + flux: flux_models.Flux = unet + flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() + + return flux + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + + parser.add_argument( + "--split_mode", + action="store_true", + # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", + help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead." + " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = FluxNetworkTrainer() + trainer.train(args) diff --git a/flux_train_network.py b/flux_train_network.py index 704c4d32..0feb9b01 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -125,7 +125,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + controlnet = flux_utils.load_controlnet() + controlnet.train() + + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model, controlnet def get_tokenize_strategy(self, args): _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) diff --git a/library/flux_models.py b/library/flux_models.py index fa3c7ad2..a3bd1974 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1013,6 +1013,8 @@ class Flux(nn.Module): txt_ids: Tensor, timesteps: Tensor, y: Tensor, + block_controlnet_hidden_states=None, + block_controlnet_single_hidden_states=None, guidance: Tensor | None = None, txt_attention_mask: Tensor | None = None, ) -> Tensor: @@ -1031,18 +1033,29 @@ class Flux(nn.Module): ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) + if block_controlnet_hidden_states is not None: + controlnet_depth = len(block_controlnet_hidden_states) + if block_controlnet_single_hidden_states is not None: + controlnet_single_depth = len(block_controlnet_single_hidden_states) if not self.blocks_to_swap: - for block in self.double_blocks: + for block_idx, block in enumerate(self.double_blocks): img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_hidden_states is not None: + img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] + img = torch.cat((txt, img), 1) for block in self.single_blocks: img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_single_hidden_states is not None: + img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] else: for block_idx, block in enumerate(self.double_blocks): self.offloader_double.wait_for_block(block_idx) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_hidden_states is not None: + img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) @@ -1052,6 +1065,8 @@ class Flux(nn.Module): self.offloader_single.wait_for_block(block_idx) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_single_hidden_states is not None: + img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) @@ -1066,6 +1081,246 @@ class Flux(nn.Module): return img +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNetFlux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams, controlnet_depth=2): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(0) # TMP + ] + ) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + self.blocks_to_swap = None + + self.offloader_double = None + self.offloader_single = None + self.num_double_blocks = len(self.double_blocks) + self.num_single_blocks = len(self.single_blocks) + + # add ControlNet blocks + self.controlnet_blocks_for_double = nn.ModuleList([]) + for _ in range(controlnet_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks_for_double.append(controlnet_block) + self.controlnet_blocks_for_single = nn.ModuleList([]) + for _ in range(controlnet_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks_for_single.append(controlnet_block) + self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.gradient_checkpointing = False + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(nn.Conv2d(16, 16, 3, padding=1)) + ) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + self.time_in.enable_gradient_checkpointing() + self.vector_in.enable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.enable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) + + print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + self.time_in.disable_gradient_checkpointing() + self.vector_in.disable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.disable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def enable_block_swap(self, num_blocks: int, device: torch.device): + self.blocks_to_swap = num_blocks + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + + assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( + f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. " + f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." + ) + + self.offloader_double = custom_offloading_utils.ModelOffloader( + self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True + ) + self.offloader_single = custom_offloading_utils.ModelOffloader( + self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True + ) + print( + f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." + ) + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_double_blocks = self.double_blocks + save_single_blocks = self.single_blocks + self.double_blocks = None + self.single_blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.double_blocks = save_double_blocks + self.single_blocks = save_single_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + controlnet_cond: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, + ) -> tuple[tuple[Tensor]]: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + block_samples = () + block_single_samples = () + if not self.blocks_to_swap: + for block_idx, block in enumerate(self.double_blocks): + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_samples = block_samples + (img,) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_single_samples = block_single_samples + (img,) + else: + for block_idx, block in enumerate(self.double_blocks): + self.offloader_double.wait_for_block(block_idx) + + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_samples = block_samples + (img,) + + self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) + + img = torch.cat((txt, img), 1) + + for block_idx, block in enumerate(self.single_blocks): + self.offloader_single.wait_for_block(block_idx) + + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_single_samples = block_single_samples + (img,) + + self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) + + controlnet_block_samples = () + controlnet_single_block_samples = () + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double): + block_sample = controlnet_block(block_sample) + controlnet_block_samples = controlnet_block_samples + (block_sample,) + for block_sample, controlnet_block in zip(block_samples, self.controlnet_single_blocks_for_single): + block_sample = controlnet_block(block_sample) + controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,) + + return controlnet_block_samples, controlnet_single_block_samples + + """ class FluxUpper(nn.Module): "" diff --git a/library/flux_utils.py b/library/flux_utils.py index f3093615..678efbc8 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -153,6 +153,14 @@ def load_ae( return ae +def load_controlnet(name, device, transformer=None): + with torch.device(device): + controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params) + if transformer is not None: + controlnet.load_state_dict(transformer.state_dict(), strict=False) + return controlnet + + def load_clip_l( ckpt_path: Optional[str], dtype: torch.dtype, From 42f6edf3a886287b99770bc7a8c0bafd3fa03f39 Mon Sep 17 00:00:00 2001 From: minux302 Date: Fri, 15 Nov 2024 23:48:51 +0900 Subject: [PATCH 310/356] fix for adding controlnet --- flux_train_control_net.py | 1340 +++++++++++++++++++++-------------- flux_train_network.py | 3 - library/flux_train_utils.py | 32 +- library/flux_utils.py | 11 +- 4 files changed, 855 insertions(+), 531 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 704c4d32..8a7be75f 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -1,563 +1,860 @@ +# training with captions + +# Swap blocks between CPU and GPU: +# This implementation is inspired by and based on the work of 2kpr. +# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading. +# The original idea has been adapted and extended to fit the current project's needs. + +# Key features: +# - CPU offloading during forward and backward passes +# - Use of fused optimizer and grad_hook for efficient gradient processing +# - Per-block fused optimizer instances + import argparse +from concurrent.futures import ThreadPoolExecutor import copy import math -import random -from typing import Any, Optional +import os +from multiprocessing import Value +import time +from typing import List, Optional, Tuple, Union +import toml + +from tqdm import tqdm import torch -from accelerate import Accelerator +import torch.nn as nn +from library import utils from library.device_utils import init_ipex, clean_memory_on_device init_ipex() -from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util -import train_network -from library.utils import setup_logging +from accelerate.utils import set_seed +from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments setup_logging() import logging logger = logging.getLogger(__name__) +import library.config_util as config_util -class FluxNetworkTrainer(train_network.NetworkTrainer): - def __init__(self): - super().__init__() - self.sample_prompts_te_outputs = None - self.is_schnell: Optional[bool] = None - self.is_swapping_blocks: bool = False +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments - def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) - # sdxl_train_util.verify_sdxl_training_args(args) - if args.fp8_base_unet: - args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) - if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: - logger.warning( - "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" - ) - args.cache_text_encoder_outputs = True + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check - if args.cache_text_encoder_outputs: - assert ( - train_dataset_group.is_text_encoder_output_cacheable() - ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True - # prepare CLIP-L/T5XXL training flags - self.train_clip_l = not args.network_train_unet_only - self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + if args.cpu_offload_checkpointing and not args.gradient_checkpointing: + logger.warning( + "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります" + ) + args.gradient_checkpointing = True - if args.max_token_length is not None: - logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, ( + "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) - assert ( - args.blocks_to_swap is None or args.blocks_to_swap == 0 - ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None - # deprecated split_mode option - if args.split_mode: - if args.blocks_to_swap is not None: + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): logger.warning( - "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored." - " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。" - ) - else: - logger.warning( - "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set." - " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。" - ) - args.blocks_to_swap = 18 # 18 is safe for most cases - - train_dataset_group.verify_bucket_reso_steps(32) # TODO check this - - def load_target_model(self, args, weight_dtype, accelerator): - # currently offload to cpu for some models - - # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) - loading_dtype = None if args.fp8_base else weight_dtype - - # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future - self.is_schnell, model = flux_utils.load_flow_model( - args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors - ) - if args.fp8_base: - # check dtype of model - if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: - raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") - elif model.dtype == torch.float8_e4m3fn: - logger.info("Loaded fp8 FLUX model") - else: - logger.info( - "Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." - " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" - ) - model.to(torch.float8_e4m3fn) - - # if args.split_mode: - # model = self.prepare_split_model(model, weight_dtype, accelerator) - - self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 - if self.is_swapping_blocks: - # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. - logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") - model.enable_block_swap(args.blocks_to_swap, accelerator.device) - - clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - clip_l.eval() - - # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) - if args.fp8_base and not args.fp8_base_unet: - loading_dtype = None # as is - else: - loading_dtype = weight_dtype - - # loading t5xxl to cpu takes a long time, so we should load to gpu in future - t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - t5xxl.eval() - if args.fp8_base and not args.fp8_base_unet: - # check dtype of model - if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: - raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") - elif t5xxl.dtype == torch.float8_e4m3fn: - logger.info("Loaded fp8 T5XXL model") - - ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - - return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model - - def get_tokenize_strategy(self, args): - _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) - - if args.t5xxl_max_token_length is None: - if is_schnell: - t5xxl_max_token_length = 256 - else: - t5xxl_max_token_length = 512 - else: - t5xxl_max_token_length = args.t5xxl_max_token_length - - logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") - return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) - - def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy): - return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl] - - def get_latents_caching_strategy(self, args): - latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) - return latents_caching_strategy - - def get_text_encoding_strategy(self, args): - return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) - - def post_process_network(self, args, accelerator, network, text_encoders, unet): - # check t5xxl is trained or not - self.train_t5xxl = network.train_t5xxl - - if self.train_t5xxl and args.cache_text_encoder_outputs: - raise ValueError( - "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" - ) - - def get_models_for_text_encoding(self, args, accelerator, text_encoders): - if args.cache_text_encoder_outputs: - if self.train_clip_l and not self.train_t5xxl: - return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached - else: - return None # no text encoders are needed for encoding because both are cached - else: - return text_encoders # both CLIP-L and T5XXL are needed for encoding - - def get_text_encoders_train_flags(self, args, text_encoders): - return [self.train_clip_l, self.train_t5xxl] - - def get_text_encoder_outputs_caching_strategy(self, args): - if args.cache_text_encoder_outputs: - # if the text encoders is trained, we need tokenization, so is_partial is True - return strategy_flux.FluxTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, - args.text_encoder_batch_size, - args.skip_cache_check, - is_partial=self.train_clip_l or self.train_t5xxl, - apply_t5_attn_mask=args.apply_t5_attn_mask, - ) - else: - return None - - def cache_text_encoder_outputs_if_needed( - self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype - ): - if args.cache_text_encoder_outputs: - if not args.lowram: - # メモリ消費を減らす - logger.info("move vae and unet to cpu to save memory") - org_vae_device = vae.device - org_unet_device = unet.device - vae.to("cpu") - unet.to("cpu") - clean_memory_on_device(accelerator.device) - - # When TE is not be trained, it will not be prepared so we need to use explicit autocast - logger.info("move text encoders to gpu") - text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 - text_encoders[1].to(accelerator.device) - - if text_encoders[1].dtype == torch.float8_e4m3fn: - # if we load fp8 weights, the model is already fp8, so we use it as is - self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) - else: - # otherwise, we need to convert it to target dtype - text_encoders[1].to(weight_dtype) - - with accelerator.autocast(): - dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) - - # cache sample prompts - if args.sample_prompts is not None: - logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") - - tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() - text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - - prompts = train_util.load_prompts(args.sample_prompts) - sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs - with accelerator.autocast(), torch.no_grad(): - for prompt_dict in prompts: - for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: - if p not in sample_prompts_te_outputs: - logger.info(f"cache Text Encoder outputs for prompt: {p}") - tokens_and_masks = tokenize_strategy.tokenize(p) - sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( - tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask - ) - self.sample_prompts_te_outputs = sample_prompts_te_outputs - - accelerator.wait_for_everyone() - - # move back to cpu - if not self.is_train_text_encoder(args): - logger.info("move CLIP-L back to cpu") - text_encoders[0].to("cpu") - logger.info("move t5XXL back to cpu") - text_encoders[1].to("cpu") - clean_memory_on_device(accelerator.device) - - if not args.lowram: - logger.info("move vae and unet back to original device") - vae.to(org_vae_device) - unet.to(org_unet_device) - else: - # Text Encoderから毎回出力を取得するので、GPUに乗せておく - text_encoders[0].to(accelerator.device, dtype=weight_dtype) - text_encoders[1].to(accelerator.device) - - # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype - - # # get size embeddings - # orig_size = batch["original_sizes_hw"] - # crop_size = batch["crop_top_lefts"] - # target_size = batch["target_sizes_hw"] - # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) - - # # concat embeddings - # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds - # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) - # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) - - # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) - # return noise_pred - - def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): - text_encoders = text_encoder # for compatibility - text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) - - flux_train_utils.sample_images( - accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs - ) - # return - - """ - class FluxUpperLowerWrapper(torch.nn.Module): - def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device): - super().__init__() - self.flux_upper = flux_upper - self.flux_lower = flux_lower - self.target_device = device - - def prepare_block_swap_before_forward(self): - pass - - def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): - self.flux_lower.to("cpu") - clean_memory_on_device(self.target_device) - self.flux_upper.to(self.target_device) - img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask) - self.flux_upper.to("cpu") - clean_memory_on_device(self.target_device) - self.flux_lower.to(self.target_device) - return self.flux_lower(img, txt, vec, pe, txt_attention_mask) - - wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) - clean_memory_on_device(accelerator.device) - flux_train_utils.sample_images( - accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs - ) - clean_memory_on_device(accelerator.device) - """ - - def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) - self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) - return noise_scheduler - - def encode_images_to_latents(self, args, accelerator, vae, images): - return vae.encode(images) - - def shift_scale_latents(self, args, latents): - return latents - - def get_noise_pred_and_target( - self, - args, - accelerator, - noise_scheduler, - latents, - batch, - text_encoder_conds, - unet: flux_models.Flux, - network, - weight_dtype, - train_unet, - ): - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] - - # get noisy model input and timesteps - noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, accelerator.device, weight_dtype - ) - - # pack latents and get img_ids - packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 - packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 - img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) - - # get guidance - # ensure guidance_scale in args is float - guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) - - # ensure the hidden state will require grad - if args.gradient_checkpointing: - noisy_model_input.requires_grad_(True) - for t in text_encoder_conds: - if t is not None and t.dtype.is_floating_point: - t.requires_grad_(True) - img_ids.requires_grad_(True) - guidance_vec.requires_grad_(True) - - # Predict the noise residual - l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds - if not args.apply_t5_attn_mask: - t5_attn_mask = None - - def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): - # if not args.split_mode: - # normal forward - with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - model_pred = unet( - img=img, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) - """ - else: - # split forward to reduce memory usage - assert network.train_blocks == "single", "train_blocks must be single for split mode" - with accelerator.autocast(): - # move flux lower to cpu, and then move flux upper to gpu - unet.to("cpu") - clean_memory_on_device(accelerator.device) - self.flux_upper.to(accelerator.device) - - # upper model does not require grad - with torch.no_grad(): - intermediate_img, intermediate_txt, vec, pe = self.flux_upper( - img=packed_noisy_model_input, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) - - # move flux upper back to cpu, and then move flux lower to gpu - self.flux_upper.to("cpu") - clean_memory_on_device(accelerator.device) - unet.to(accelerator.device) - - # lower model requires grad - intermediate_img.requires_grad_(True) - intermediate_txt.requires_grad_(True) - vec.requires_grad_(True) - pe.requires_grad_(True) - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) - """ - - return model_pred - - model_pred = call_dit( - img=packed_noisy_model_input, - img_ids=img_ids, - t5_out=t5_out, - txt_ids=txt_ids, - l_pooled=l_pooled, - timesteps=timesteps, - guidance_vec=guidance_vec, - t5_attn_mask=t5_attn_mask, - ) - - # unpack latents - model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) - - # apply model prediction type - model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) - - # flow matching loss: this is different from SD3 - target = noise - latents - - # differential output preservation - if "custom_attributes" in batch: - diff_output_pr_indices = [] - for i, custom_attributes in enumerate(batch["custom_attributes"]): - if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: - diff_output_pr_indices.append(i) - - if len(diff_output_pr_indices) > 0: - network.set_multiplier(0.0) - with torch.no_grad(): - model_pred_prior = call_dit( - img=packed_noisy_model_input[diff_output_pr_indices], - img_ids=img_ids[diff_output_pr_indices], - t5_out=t5_out[diff_output_pr_indices], - txt_ids=txt_ids[diff_output_pr_indices], - l_pooled=l_pooled[diff_output_pr_indices], - timesteps=timesteps[diff_output_pr_indices], - guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None, - t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None, + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) ) - network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step - - model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width) - model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( - args, - model_pred_prior, - noisy_model_input[diff_output_pr_indices], - sigmas[diff_output_pr_indices] if sigmas is not None else None, ) - target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - - return model_pred, target, timesteps, None, weighting - - def post_process_loss(self, loss, args, timesteps, noise_scheduler): - return loss - - def get_sai_model_spec(self, args): - return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") - - def update_metadata(self, metadata, args): - metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask - metadata["ss_weighting_scheme"] = args.weighting_scheme - metadata["ss_logit_mean"] = args.logit_mean - metadata["ss_logit_std"] = args.logit_std - metadata["ss_mode_scale"] = args.mode_scale - metadata["ss_guidance_scale"] = args.guidance_scale - metadata["ss_timestep_sampling"] = args.timestep_sampling - metadata["ss_sigmoid_scale"] = args.sigmoid_scale - metadata["ss_model_prediction_type"] = args.model_prediction_type - metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift - - def is_text_encoder_not_needed_for_training(self, args): - return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) - - def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): - if index == 0: # CLIP-L - return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) - else: # T5XXL - text_encoder.encoder.embed_tokens.requires_grad_(True) - - def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): - if index == 0: # CLIP-L - logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") - text_encoder.to(te_weight_dtype) # fp8 - text_encoder.text_model.embeddings.to(dtype=weight_dtype) - else: # T5XXL - - def prepare_fp8(text_encoder, target_dtype): - def forward_hook(module): - def forward(hidden_states): - hidden_gelu = module.act(module.wi_0(hidden_states)) - hidden_linear = module.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = module.dropout(hidden_states) - - hidden_states = module.wo(hidden_states) - return hidden_states - - return forward - - for module in text_encoder.modules(): - if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: - # print("set", module.__class__.__name__, "to", target_dtype) - module.to(target_dtype) - if module.__class__.__name__ in ["T5DenseGatedActDense"]: - # print("set", module.__class__.__name__, "hooks") - module.forward = forward_hook(module) - - if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: - logger.info(f"T5XXL already prepared for fp8") + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } else: - logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") - text_encoder.to(te_weight_dtype) # fp8 - prepare_fp8(text_encoder, weight_dtype) + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } - def prepare_unet_with_accelerator( - self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module - ) -> torch.nn.Module: - if not self.is_swapping_blocks: - return super().prepare_unet_with_accelerator(args, accelerator, unet) + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False + ) + ) + t5xxl_max_token_length = ( + args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512) + ) + strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) + + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + + # load VAE for caching latents + ae = None + if cache_latents: + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae.to(accelerator.device, dtype=weight_dtype) + ae.requires_grad_(False) + ae.eval() + + train_dataset_group.new_cache_latents(ae, accelerator) + + ae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # prepare tokenize strategy + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy) + + # load clip_l, t5xxl for caching text encoder outputs + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + clip_l.eval() + t5xxl.eval() + clip_l.requires_grad_(False) + t5xxl.requires_grad_(False) + + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad here + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = flux_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + clip_l = None + t5xxl = None + clean_memory_on_device(accelerator.device) + + # load FLUX + _, flux = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + ) + flux.requires_grad_(False) + + # load controlnet + controlnet = flux_utils.load_controlnet() + controlnet.requires_grad_(True) + + if args.gradient_checkpointing: + controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) + + # block swap + + # backward compatibility + if args.blocks_to_swap is None: + blocks_to_swap = args.double_blocks_to_swap or 0 + if args.single_blocks_to_swap is not None: + blocks_to_swap += args.single_blocks_to_swap // 2 + if blocks_to_swap > 0: + logger.warning( + "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead." + " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。" + ) + logger.info( + f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}." + ) + args.blocks_to_swap = blocks_to_swap + del blocks_to_swap + + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + flux.enable_block_swap(args.blocks_to_swap, accelerator.device) + controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device) + + if not cache_latents: + # load VAE here if not cached + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu") + ae.requires_grad_(False) + ae.eval() + ae.to(accelerator.device, dtype=weight_dtype) + + training_models = [] + params_to_optimize = [] + training_models.append(controlnet) + name_and_params = list(controlnet.named_parameters()) + # single param group for now + params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate}) + param_names = [[n for n, _ in name_and_params]] + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.blockwise_fused_optimizers: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. + # This balances memory usage and management complexity. + + # split params into groups. currently different learning rates are not supported + grouped_params = [] + param_group = {} + for group in params_to_optimize: + named_parameters = list(controlnet.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # double, single or other + if np[0].startswith("double_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "double" + elif np[0].startswith("single_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "single" + else: + block_index = -1 + + param_group_key = (block_type, block_index) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.blockwise_fused_optimizers: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + flux.to(weight_dtype) + controlnet.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) # TODO check works with fp16 or not + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + flux.to(weight_dtype) + controlnet.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) + + # if we don't cache text encoder outputs, move them to device + if not args.cache_text_encoder_outputs: + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # accelerator does some magic # if we doesn't swap blocks, we can move the model to device - flux: flux_models.Flux = unet - flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) - accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage - accelerator.unwrap_model(flux).prepare_block_swap_before_forward() + controlnet = accelerator.prepare(controlnet, device_placement=[not is_swapping_blocks]) + if is_swapping_blocks: + accelerator.unwrap_model(controlnet).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) - return flux + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): + if parameter.requires_grad: + + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook + + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) + + elif args.blockwise_fused_optimizers: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if is_swapping_blocks: + accelerator.unwrap_model(controlnet).prepare_block_swap_before_forward() + + # For --sample_at_first + optimizer_eval_fn() + flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + optimizer_train_fn() + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) + + loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.blockwise_fused_optimizers: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + + # call model + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + with accelerator.autocast(): + block_samples, block_single_samples = controlnet( + img=packed_noisy_model_input, + img_ids=img_ids, + controlnet_cond=batch["control_image"].to(accelerator.device), + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = flux( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + loss = train_util.conditional_loss( + model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None + ) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + optimizer_eval_fn() + flux_train_utils.sample_images( + accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(flux), + ) + optimizer_train_fn() + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if len(accelerator.trackers) > 0: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + optimizer_eval_fn() + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(flux), + ) + + flux_train_utils.sample_images( + accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + ) + optimizer_train_fn() + + is_main_process = accelerator.is_main_process + # if is_main_process: + controlnet = accelerator.unwrap_model(controlnet) + + accelerator.end_training() + optimizer_eval_fn() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux) + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: - parser = train_network.setup_parser() + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) # TODO split this + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) # TODO remove this from here train_util.add_dit_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) parser.add_argument( - "--split_mode", + "--mem_eff_save", action="store_true", - # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" - # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", - help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead." - " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。", + help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う", + ) + + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます", + ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", + ) + parser.add_argument( + "--double_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) + parser.add_argument( + "--single_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", ) return parser @@ -569,5 +866,4 @@ if __name__ == "__main__": train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) - trainer = FluxNetworkTrainer() - trainer.train(args) + train(args) diff --git a/flux_train_network.py b/flux_train_network.py index 0feb9b01..6668012e 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -125,9 +125,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - controlnet = flux_utils.load_controlnet() - controlnet.train() - return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model, controlnet def get_tokenize_strategy(self, args): diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index d90644a2..cc3bcb0e 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -40,6 +40,7 @@ def sample_images( text_encoders, sample_prompts_te_outputs, prompt_replacement=None, + controlnet=None ): if steps == 0: if not args.sample_at_first: @@ -67,6 +68,8 @@ def sample_images( flux = accelerator.unwrap_model(flux) if text_encoders is not None: text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + if controlnet is not None: + controlnet = accelerator.unwrap_model(controlnet) # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) prompts = train_util.load_prompts(args.sample_prompts) @@ -98,6 +101,7 @@ def sample_images( steps, sample_prompts_te_outputs, prompt_replacement, + controlnet ) else: # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) @@ -121,6 +125,7 @@ def sample_images( steps, sample_prompts_te_outputs, prompt_replacement, + controlnet ) torch.set_rng_state(rng_state) @@ -142,6 +147,7 @@ def sample_image_inference( steps, sample_prompts_te_outputs, prompt_replacement, + controlnet ): assert isinstance(prompt_dict, dict) # negative_prompt = prompt_dict.get("negative_prompt") @@ -150,7 +156,7 @@ def sample_image_inference( height = prompt_dict.get("height", 512) scale = prompt_dict.get("scale", 3.5) seed = prompt_dict.get("seed") - # controlnet_image = prompt_dict.get("controlnet_image") + controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) @@ -169,6 +175,9 @@ def sample_image_inference( # if negative_prompt is None: # negative_prompt = "" + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) height = max(64, height - height % 16) # round to divisible by 16 width = max(64, width - width % 16) # round to divisible by 16 @@ -224,7 +233,7 @@ def sample_image_inference( t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None with accelerator.autocast(), torch.no_grad(): - x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask) + x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) x = x.float() x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) @@ -301,18 +310,37 @@ def denoise( timesteps: list[float], guidance: float = 4.0, t5_attn_mask: Optional[torch.Tensor] = None, + controlnet: Optional[flux_models.ControlNetFlux] = None, + controlnet_img: Optional[torch.Tensor] = None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) model.prepare_block_swap_before_forward() + if controlnet is not None: + block_samples, block_single_samples = controlnet( + img=img, + img_ids=img_ids, + controlnet_cond=controlnet_img, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + else: + block_samples = None + block_single_samples = None pred = model( img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, timesteps=t_vec, guidance=guidance_vec, txt_attention_mask=t5_attn_mask, diff --git a/library/flux_utils.py b/library/flux_utils.py index 678efbc8..7b538d13 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -153,11 +153,14 @@ def load_ae( return ae -def load_controlnet(name, device, transformer=None): - with torch.device(device): +def load_controlnet(): + # TODO + is_schnell = False + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + with torch.device("meta"): controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params) - if transformer is not None: - controlnet.load_state_dict(transformer.state_dict(), strict=False) + # if transformer is not None: + # controlnet.load_state_dict(transformer.state_dict(), strict=False) return controlnet From e358b118afbc93f63dbb5ab6d2412ec553ea9cd7 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sat, 16 Nov 2024 14:49:29 +0900 Subject: [PATCH 311/356] fix dataloader --- flux_train_control_net.py | 84 ++++++++++++++++++++------------------- library/flux_models.py | 17 ++++---- 2 files changed, 52 insertions(+), 49 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 8a7be75f..ee4d0ebf 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -11,31 +11,36 @@ # - Per-block fused optimizer instances import argparse -from concurrent.futures import ThreadPoolExecutor import copy import math import os -from multiprocessing import Value import time +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Value from typing import List, Optional, Tuple, Union + import toml - -from tqdm import tqdm - import torch import torch.nn as nn +from tqdm import tqdm + from library import utils -from library.device_utils import init_ipex, clean_memory_on_device +from library.device_utils import clean_memory_on_device, init_ipex init_ipex() from accelerate.utils import set_seed -from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux -from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler import library.train_util as train_util - -from library.utils import setup_logging, add_logging_arguments +from library import ( + deepspeed_utils, + flux_train_utils, + flux_utils, + strategy_base, + strategy_flux, +) +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler +from library.utils import add_logging_arguments, setup_logging setup_logging() import logging @@ -46,10 +51,10 @@ import library.config_util as config_util # import library.sdxl_train_util as sdxl_train_util from library.config_util import ( - ConfigSanitizer, BlueprintGenerator, + ConfigSanitizer, ) -from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments +from library.custom_train_functions import add_custom_train_arguments, apply_masked_loss def train(args): @@ -85,7 +90,6 @@ def train(args): ) cache_latents = args.cache_latents - use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する @@ -103,7 +107,7 @@ def train(args): if args.dataset_config is not None: logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] + ignored = ["train_data_dir", "conditioing_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( @@ -111,31 +115,17 @@ def train(args): ) ) else: - if use_dreambooth_method: - logger.info("Using DreamBooth method.") - user_config = { - "datasets": [ - { - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( - args.train_data_dir, args.reg_data_dir - ) - } - ] - } - else: - logger.info("Training with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension + ) + } + ] + } blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) @@ -648,12 +638,12 @@ def train(args): l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds if not args.apply_t5_attn_mask: t5_attn_mask = None - + with accelerator.autocast(): block_samples, block_single_samples = controlnet( img=packed_noisy_model_input, img_ids=img_ids, - controlnet_cond=batch["control_image"].to(accelerator.device), + controlnet_img=batch["conditioing_image"].to(accelerator.device), txt=t5_out, txt_ids=txt_ids, y=l_pooled, @@ -856,6 +846,18 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="controlnet model name or path / controlnetのモデル名またはパス", + ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) return parser diff --git a/library/flux_models.py b/library/flux_models.py index a3bd1974..b52ea6f0 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -2,15 +2,15 @@ # license: Apache-2.0 License -from concurrent.futures import Future, ThreadPoolExecutor -from dataclasses import dataclass import math import os import time +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass from typing import Dict, List, Optional, Union from library import utils -from library.device_utils import init_ipex, clean_memory_on_device +from library.device_utils import clean_memory_on_device, init_ipex init_ipex() @@ -18,6 +18,7 @@ import torch from einops import rearrange from torch import Tensor, nn from torch.utils.checkpoint import checkpoint + from library import custom_offloading_utils # USE_REENTRANT = True @@ -1251,7 +1252,7 @@ class ControlNetFlux(nn.Module): self, img: Tensor, img_ids: Tensor, - controlnet_cond: Tensor, + controlnet_img: Tensor, txt: Tensor, txt_ids: Tensor, timesteps: Tensor, @@ -1264,10 +1265,10 @@ class ControlNetFlux(nn.Module): # running on sequences img img = self.img_in(img) - controlnet_cond = self.input_hint_block(controlnet_cond) - controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - controlnet_cond = self.pos_embed_input(controlnet_cond) - img = img + controlnet_cond + controlnet_img = self.input_hint_block(controlnet_img) + controlnet_img = rearrange(controlnet_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_img = self.pos_embed_input(controlnet_img) + img = img + controlnet_img vec = self.time_in(timestep_embedding(timesteps, 256)) if self.params.guidance_embed: if guidance is None: From 2a188f07e682ed5dd958821a223d48c17a9aeb83 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 17 Nov 2024 16:12:10 +0900 Subject: [PATCH 312/356] Fix to work DOP with bock swap --- flux_train_network.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flux_train_network.py b/flux_train_network.py index 704c4d32..679db62b 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -445,6 +445,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): if len(diff_output_pr_indices) > 0: network.set_multiplier(0.0) + unet.prepare_block_swap_before_forward() with torch.no_grad(): model_pred_prior = call_dit( img=packed_noisy_model_input[diff_output_pr_indices], From b2660bbe7410d7ffa40906a7a09f84a17139cb46 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sun, 17 Nov 2024 10:24:57 +0000 Subject: [PATCH 313/356] train run --- flux_train_control_net.py | 39 ++++++++++++++++++++++--------------- library/flux_models.py | 30 ++++++++++++++-------------- library/flux_train_utils.py | 2 +- library/flux_utils.py | 2 +- 4 files changed, 40 insertions(+), 33 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index ee4d0ebf..205ff6b6 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -103,11 +103,11 @@ def train(args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) if args.dataset_config is not None: logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "conditioing_data_dir"] + ignored = ["train_data_dir", "conditioning_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( @@ -263,10 +263,11 @@ def train(args): args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors ) flux.requires_grad_(False) + flux.to(accelerator.device) # load controlnet controlnet = flux_utils.load_controlnet() - controlnet.requires_grad_(True) + controlnet.train() if args.gradient_checkpointing: controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) @@ -443,7 +444,8 @@ def train(args): clean_memory_on_device(accelerator.device) - if args.deepspeed: + # if args.deepspeed: + if True: ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -612,8 +614,10 @@ def train(args): text_encoder_conds = text_encoding_strategy.encode_tokens( flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask ) - if args.full_fp16: - text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + # if args.full_fp16: + # text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + # TODO: check + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps @@ -629,10 +633,10 @@ def train(args): # pack latents and get img_ids packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 - img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device).to(weight_dtype) # get guidance: ensure args.guidance_scale is float - guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device, dtype=weight_dtype) # call model l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds @@ -640,10 +644,11 @@ def train(args): t5_attn_mask = None with accelerator.autocast(): + print("control start") block_samples, block_single_samples = controlnet( img=packed_noisy_model_input, img_ids=img_ids, - controlnet_img=batch["conditioing_image"].to(accelerator.device), + controlnet_cond=batch["conditioning_images"].to(accelerator.device).to(weight_dtype), txt=t5_out, txt_ids=txt_ids, y=l_pooled, @@ -651,6 +656,8 @@ def train(args): guidance=guidance_vec, txt_attention_mask=t5_attn_mask, ) + print("control end") + print("dit start") # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = flux( img=packed_noisy_model_input, @@ -796,7 +803,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) # TODO split this - train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) train_util.add_masked_loss_arguments(parser) deepspeed_utils.add_deepspeed_arguments(parser) @@ -852,12 +859,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="controlnet model name or path / controlnetのモデル名またはパス", ) - parser.add_argument( - "--conditioning_data_dir", - type=str, - default=None, - help="conditioning data directory / 条件付けデータのディレクトリ", - ) + # parser.add_argument( + # "--conditioning_data_dir", + # type=str, + # default=None, + # help="conditioning data directory / 条件付けデータのディレクトリ", + # ) return parser diff --git a/library/flux_models.py b/library/flux_models.py index b52ea6f0..2fc21db9 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1042,20 +1042,20 @@ class Flux(nn.Module): if not self.blocks_to_swap: for block_idx, block in enumerate(self.double_blocks): img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_controlnet_hidden_states is not None: + if block_controlnet_hidden_states is not None and controlnet_depth > 0: img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] img = torch.cat((txt, img), 1) - for block in self.single_blocks: + for block_idx, block in enumerate(self.single_blocks): img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_controlnet_single_hidden_states is not None: + if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0: img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] else: for block_idx, block in enumerate(self.double_blocks): self.offloader_double.wait_for_block(block_idx) img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_controlnet_hidden_states is not None: + if block_controlnet_hidden_states is not None and controlnet_depth > 0: img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) @@ -1066,7 +1066,7 @@ class Flux(nn.Module): self.offloader_single.wait_for_block(block_idx) img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - if block_controlnet_single_hidden_states is not None: + if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0: img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) @@ -1121,14 +1121,14 @@ class ControlNetFlux(nn.Module): mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, ) - for _ in range(params.depth) + for _ in range(controlnet_depth) ] ) self.single_blocks = nn.ModuleList( [ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) - for _ in range(0) # TMP + for _ in range(0) # TODO ] ) @@ -1148,7 +1148,7 @@ class ControlNetFlux(nn.Module): controlnet_block = zero_module(controlnet_block) self.controlnet_blocks_for_double.append(controlnet_block) self.controlnet_blocks_for_single = nn.ModuleList([]) - for _ in range(controlnet_depth): + for _ in range(0): # TODO controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) controlnet_block = zero_module(controlnet_block) self.controlnet_blocks_for_single.append(controlnet_block) @@ -1252,7 +1252,7 @@ class ControlNetFlux(nn.Module): self, img: Tensor, img_ids: Tensor, - controlnet_img: Tensor, + controlnet_cond: Tensor, txt: Tensor, txt_ids: Tensor, timesteps: Tensor, @@ -1265,10 +1265,10 @@ class ControlNetFlux(nn.Module): # running on sequences img img = self.img_in(img) - controlnet_img = self.input_hint_block(controlnet_img) - controlnet_img = rearrange(controlnet_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - controlnet_img = self.pos_embed_input(controlnet_img) - img = img + controlnet_img + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond vec = self.time_in(timestep_embedding(timesteps, 256)) if self.params.guidance_embed: if guidance is None: @@ -1283,7 +1283,7 @@ class ControlNetFlux(nn.Module): block_samples = () block_single_samples = () if not self.blocks_to_swap: - for block_idx, block in enumerate(self.double_blocks): + for block in self.double_blocks: img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) block_samples = block_samples + (img,) @@ -1315,7 +1315,7 @@ class ControlNetFlux(nn.Module): for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double): block_sample = controlnet_block(block_sample) controlnet_block_samples = controlnet_block_samples + (block_sample,) - for block_sample, controlnet_block in zip(block_samples, self.controlnet_single_blocks_for_single): + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_single): block_sample = controlnet_block(block_sample) controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index cc3bcb0e..d82bde91 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -460,7 +460,7 @@ def get_noisy_model_input_and_timesteps( sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents - return noisy_model_input, timesteps, sigmas + return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): diff --git a/library/flux_utils.py b/library/flux_utils.py index 7b538d13..4a3817fd 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -157,7 +157,7 @@ def load_controlnet(): # TODO is_schnell = False name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL - with torch.device("meta"): + with torch.device("cuda:0"): controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params) # if transformer is not None: # controlnet.load_state_dict(transformer.state_dict(), strict=False) From 35778f021897796410372aed8540547ba317c2a3 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sun, 17 Nov 2024 11:09:05 +0000 Subject: [PATCH 314/356] fix sample_images type --- flux_train_control_net.py | 31 ++++++++++++++----------------- library/flux_train_utils.py | 2 +- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 205ff6b6..791900d1 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -444,8 +444,7 @@ def train(args): clean_memory_on_device(accelerator.device) - # if args.deepspeed: - if True: + if args.deepspeed: ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -644,7 +643,6 @@ def train(args): t5_attn_mask = None with accelerator.autocast(): - print("control start") block_samples, block_single_samples = controlnet( img=packed_noisy_model_input, img_ids=img_ids, @@ -656,8 +654,6 @@ def train(args): guidance=guidance_vec, txt_attention_mask=t5_attn_mask, ) - print("control end") - print("dit start") # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = flux( img=packed_noisy_model_input, @@ -763,18 +759,19 @@ def train(args): accelerator.wait_for_everyone() optimizer_eval_fn() - if args.save_every_n_epochs is not None: - if accelerator.is_main_process: - flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( - args, - True, - accelerator, - save_dtype, - epoch, - num_train_epochs, - global_step, - accelerator.unwrap_model(flux), - ) + # TODO: save cn models + # if args.save_every_n_epochs is not None: + # if accelerator.is_main_process: + # flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + # args, + # True, + # accelerator, + # save_dtype, + # epoch, + # num_train_epochs, + # global_step, + # accelerator.unwrap_model(flux), + # ) flux_train_utils.sample_images( accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index d82bde91..de2ee030 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -235,7 +235,7 @@ def sample_image_inference( with accelerator.autocast(), torch.no_grad(): x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) - x = x.float() + # x = x.float() # TODO: check x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) # latent to image From 4dd4cd6ec8c55fa94b53217181ed9c95e59eed56 Mon Sep 17 00:00:00 2001 From: minux302 Date: Mon, 18 Nov 2024 12:47:01 +0000 Subject: [PATCH 315/356] work cn load and validation --- flux_train_control_net.py | 20 ++++---------------- library/flux_models.py | 6 +++--- library/flux_train_utils.py | 18 ++++++++++++++---- library/flux_utils.py | 25 ++++++++++++++++--------- 4 files changed, 37 insertions(+), 32 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 791900d1..cbfac418 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -266,7 +266,7 @@ def train(args): flux.to(accelerator.device) # load controlnet - controlnet = flux_utils.load_controlnet() + controlnet = flux_utils.load_controlnet(args.controlnet, weight_dtype, "cpu", args.disable_mmap_load_safetensors) controlnet.train() if args.gradient_checkpointing: @@ -568,7 +568,7 @@ def train(args): # For --sample_at_first optimizer_eval_fn() - flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet) optimizer_train_fn() if len(accelerator.trackers) > 0: # log empty object to commit the sample images to wandb @@ -718,7 +718,7 @@ def train(args): optimizer_eval_fn() flux_train_utils.sample_images( - accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet ) # 指定ステップごとにモデルを保存 @@ -774,7 +774,7 @@ def train(args): # ) flux_train_utils.sample_images( - accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet ) optimizer_train_fn() @@ -850,18 +850,6 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", ) - parser.add_argument( - "--controlnet_model_name_or_path", - type=str, - default=None, - help="controlnet model name or path / controlnetのモデル名またはパス", - ) - # parser.add_argument( - # "--conditioning_data_dir", - # type=str, - # default=None, - # help="conditioning data directory / 条件付けデータのディレクトリ", - # ) return parser diff --git a/library/flux_models.py b/library/flux_models.py index 2fc21db9..4123b40e 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1142,11 +1142,11 @@ class ControlNetFlux(nn.Module): self.num_single_blocks = len(self.single_blocks) # add ControlNet blocks - self.controlnet_blocks_for_double = nn.ModuleList([]) + self.controlnet_blocks = nn.ModuleList([]) for _ in range(controlnet_depth): controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) controlnet_block = zero_module(controlnet_block) - self.controlnet_blocks_for_double.append(controlnet_block) + self.controlnet_blocks.append(controlnet_block) self.controlnet_blocks_for_single = nn.ModuleList([]) for _ in range(0): # TODO controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) @@ -1312,7 +1312,7 @@ class ControlNetFlux(nn.Module): controlnet_block_samples = () controlnet_single_block_samples = () - for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double): + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): block_sample = controlnet_block(block_sample) controlnet_block_samples = controlnet_block_samples + (block_sample,) for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_single): diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index de2ee030..dbbaba73 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -175,10 +175,6 @@ def sample_image_inference( # if negative_prompt is None: # negative_prompt = "" - if controlnet_image is not None: - controlnet_image = Image.open(controlnet_image).convert("RGB") - controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) - height = max(64, height - height % 16) # round to divisible by 16 width = max(64, width - width % 16) # round to divisible by 16 logger.info(f"prompt: {prompt}") @@ -232,6 +228,12 @@ def sample_image_inference( img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) + with accelerator.autocast(), torch.no_grad(): x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) @@ -315,6 +317,8 @@ def denoise( ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) model.prepare_block_swap_before_forward() @@ -560,6 +564,12 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提", ) parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") + parser.add_argument( + "--controlnet", + type=str, + default=None, + help="path to controlnet (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)" + ) parser.add_argument( "--t5xxl_max_token_length", type=int, diff --git a/library/flux_utils.py b/library/flux_utils.py index 4a3817fd..fb7a3074 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -153,15 +153,22 @@ def load_ae( return ae -def load_controlnet(): - # TODO - is_schnell = False - name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL - with torch.device("cuda:0"): - controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params) - # if transformer is not None: - # controlnet.load_state_dict(transformer.state_dict(), strict=False) - return controlnet +def load_controlnet( + ckpt_path: Optional[str], dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False +): + logger.info("Building ControlNet") + # is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) + is_schnell = False + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + with torch.device("meta"): + controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype) + + if ckpt_path is not None: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = controlnet.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded ControlNet: {info}") + return controlnet def load_clip_l( From 31ca899b6b5425466c814d0d9e2e4e8bfbf93001 Mon Sep 17 00:00:00 2001 From: minux302 Date: Mon, 18 Nov 2024 13:03:28 +0000 Subject: [PATCH 316/356] fix depth value --- library/flux_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/library/flux_models.py b/library/flux_models.py index 4123b40e..328ad481 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1093,7 +1093,7 @@ class ControlNetFlux(nn.Module): Transformer model for flow matching on sequences. """ - def __init__(self, params: FluxParams, controlnet_depth=2): + def __init__(self, params: FluxParams, controlnet_depth=2, controlnet_single_depth=0): super().__init__() self.params = params @@ -1128,7 +1128,7 @@ class ControlNetFlux(nn.Module): self.single_blocks = nn.ModuleList( [ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) - for _ in range(0) # TODO + for _ in range(controlnet_single_depth) ] ) @@ -1148,7 +1148,7 @@ class ControlNetFlux(nn.Module): controlnet_block = zero_module(controlnet_block) self.controlnet_blocks.append(controlnet_block) self.controlnet_blocks_for_single = nn.ModuleList([]) - for _ in range(0): # TODO + for _ in range(controlnet_single_depth): controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) controlnet_block = zero_module(controlnet_block) self.controlnet_blocks_for_single.append(controlnet_block) From 2a61fc07846dc919ea64b568f7e18c010e5c8e06 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Wed, 20 Nov 2024 21:20:35 +0900 Subject: [PATCH 317/356] docs: fix typo from block_to_swap to blocks_to_swap in README --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 81a3199b..f9c85e3a 100644 --- a/README.md +++ b/README.md @@ -68,11 +68,11 @@ When training LoRA for Text Encoder (without `--network_train_unet_only`), more __Options for GPUs with less VRAM:__ -By specifying `--block_to_swap`, you can save VRAM by swapping some blocks between CPU and GPU. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. +By specifying `--blocks_to_swap`, you can save VRAM by swapping some blocks between CPU and GPU. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. -Specify a number like `--block_to_swap 10`. A larger number will swap more blocks, saving more VRAM, but training will be slower. In FLUX.1, you can swap up to 35 blocks. +Specify a number like `--blocks_to_swap 10`. A larger number will swap more blocks, saving more VRAM, but training will be slower. In FLUX.1, you can swap up to 35 blocks. -`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--block_to_swap`. +`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--blocks_to_swap`. Adafactor optimizer may reduce the VRAM usage than 8bit AdamW. Please use settings like below: @@ -82,7 +82,7 @@ Adafactor optimizer may reduce the VRAM usage than 8bit AdamW. Please use settin The training can be done with 16GB VRAM GPUs with the batch size of 1. Please change your dataset configuration. -The training can be done with 12GB VRAM GPUs with `--block_to_swap 16` with 8bit AdamW. Please use settings like below: +The training can be done with 12GB VRAM GPUs with `--blocks_to_swap 16` with 8bit AdamW. Please use settings like below: ``` --blocks_to_swap 16 From 0b5229a9550cb921b83d22472c4785a15c42ba90 Mon Sep 17 00:00:00 2001 From: minux302 Date: Thu, 21 Nov 2024 15:55:27 +0000 Subject: [PATCH 318/356] save cn --- flux_train_control_net.py | 34 +++++++++++++++------------------- library/flux_train_utils.py | 1 - 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index cbfac418..0f38b709 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -266,7 +266,7 @@ def train(args): flux.to(accelerator.device) # load controlnet - controlnet = flux_utils.load_controlnet(args.controlnet, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, "cpu", args.disable_mmap_load_safetensors) controlnet.train() if args.gradient_checkpointing: @@ -613,9 +613,6 @@ def train(args): text_encoder_conds = text_encoding_strategy.encode_tokens( flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask ) - # if args.full_fp16: - # text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] - # TODO: check text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps @@ -733,7 +730,7 @@ def train(args): epoch, num_train_epochs, global_step, - accelerator.unwrap_model(flux), + accelerator.unwrap_model(controlnet), ) optimizer_train_fn() @@ -759,19 +756,18 @@ def train(args): accelerator.wait_for_everyone() optimizer_eval_fn() - # TODO: save cn models - # if args.save_every_n_epochs is not None: - # if accelerator.is_main_process: - # flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( - # args, - # True, - # accelerator, - # save_dtype, - # epoch, - # num_train_epochs, - # global_step, - # accelerator.unwrap_model(flux), - # ) + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(controlnet), + ) flux_train_utils.sample_images( accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet @@ -791,7 +787,7 @@ def train(args): del accelerator # この後メモリを使うのでこれは消す if is_main_process: - flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux) + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, controlnet) logger.info("model saved.") diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index dbbaba73..5e25c7fe 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -237,7 +237,6 @@ def sample_image_inference( with accelerator.autocast(), torch.no_grad(): x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) - # x = x.float() # TODO: check x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) # latent to image From 420a180d938c7b5a6e3006b1719dbfeaae72a2cc Mon Sep 17 00:00:00 2001 From: recris Date: Wed, 27 Nov 2024 18:11:51 +0000 Subject: [PATCH 319/356] Implement pseudo Huber loss for Flux and SD3 --- fine_tune.py | 6 +-- flux_train.py | 2 +- flux_train_network.py | 2 +- library/train_util.py | 74 ++++++++++++++++------------ sd3_train.py | 2 +- sd3_train_network.py | 2 +- sdxl_train.py | 6 +-- sdxl_train_control_net.py | 4 +- sdxl_train_control_net_lllite.py | 4 +- sdxl_train_control_net_lllite_old.py | 6 ++- train_controlnet.py | 6 +-- train_db.py | 4 +- train_network.py | 9 ++-- train_textual_inversion.py | 4 +- train_textual_inversion_XTI.py | 6 ++- 15 files changed, 76 insertions(+), 61 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 0090bd19..70959a75 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -380,7 +380,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -397,7 +397,7 @@ def train(args): if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: # do not mean over batch dimension for snr weight or scale v-pred loss loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) loss = loss.mean([1, 2, 3]) @@ -411,7 +411,7 @@ def train(args): loss = loss.mean() # mean over batch dimension else: loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) accelerator.backward(loss) diff --git a/flux_train.py b/flux_train.py index a89e2f13..f6e43b27 100644 --- a/flux_train.py +++ b/flux_train.py @@ -667,7 +667,7 @@ def train(args): # calculate loss loss = train_util.conditional_loss( - model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if weighting is not None: loss = loss * weighting diff --git a/flux_train_network.py b/flux_train_network.py index 679db62b..04287f39 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -468,7 +468,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ) target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - return model_pred, target, timesteps, None, weighting + return model_pred, target, timesteps, weighting def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss diff --git a/library/train_util.py b/library/train_util.py index 25cf7640..c204ebd3 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3905,7 +3905,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--huber_c", type=float, default=0.1, - help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + ) + + parser.add_argument( + "--huber_scale", + type=float, + default=1.0, + help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", ) parser.add_argument( @@ -5821,29 +5828,10 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") - - if args.loss_type == "huber" or args.loss_type == "smooth_l1": - if args.huber_schedule == "exponential": - alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps - huber_c = torch.exp(-alpha * timesteps) - elif args.huber_schedule == "snr": - alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps) - sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 - huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c - elif args.huber_schedule == "constant": - huber_c = torch.full((b_size,), args.huber_c) - else: - raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - huber_c = huber_c.to(device) - elif args.loss_type == "l2": - huber_c = None # may be anything, as it's not used - else: - raise NotImplementedError(f"Unknown loss type {args.loss_type}") - - timesteps = timesteps.long().to(device) - return timesteps, huber_c +def get_timesteps(min_timestep, max_timestep, b_size, device): + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) + timesteps = timesteps.long() + return timesteps def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): @@ -5865,7 +5853,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): min_timestep = 0 if args.min_timestep is None else args.min_timestep max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep - timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device) + timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -5878,24 +5866,46 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): else: noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - return noise, noisy_latents, timesteps, huber_c + return noise, noisy_latents, timesteps + + +def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor: + b_size = timesteps.shape[0] + if args.huber_schedule == "exponential": + alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps + result = torch.exp(-alpha * timesteps) * args.huber_scale + elif args.huber_schedule == "snr": + if not hasattr(noise_scheduler, 'alphas_cumprod'): + raise NotImplementedError(f"Huber schedule 'snr' is not supported with the current model.") + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) + sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 + result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c + result = result.to(timesteps.device) + elif args.huber_schedule == "constant": + result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device) + else: + raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") + + return result def conditional_loss( - model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor] + args, model_pred: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor, reduction: str, noise_scheduler ): - if loss_type == "l2": + if args.loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) - elif loss_type == "l1": + elif args.loss_type == "l1": loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) - elif loss_type == "huber": + elif args.loss_type == "huber": + huber_c = get_huber_threshold(args, timesteps, noise_scheduler) huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) elif reduction == "sum": loss = torch.sum(loss) - elif loss_type == "smooth_l1": + elif args.loss_type == "smooth_l1": + huber_c = get_huber_threshold(args, timesteps, noise_scheduler) huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": @@ -5903,7 +5913,7 @@ def conditional_loss( elif reduction == "sum": loss = torch.sum(loss) else: - raise NotImplementedError(f"Unsupported Loss Type {loss_type}") + raise NotImplementedError(f"Unsupported Loss Type: {args.loss_type}") return loss diff --git a/sd3_train.py b/sd3_train.py index 96ec951b..cf2bdf93 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -845,7 +845,7 @@ def train(args): # ) # calculate loss loss = train_util.conditional_loss( - model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None + args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) diff --git a/sd3_train_network.py b/sd3_train_network.py index 1726e325..fb7711bd 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -378,7 +378,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - return model_pred, target, timesteps, None, weighting + return model_pred, target, timesteps, weighting def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss diff --git a/sdxl_train.py b/sdxl_train.py index e26f4aa1..1bc27ec6 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -695,7 +695,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -720,7 +720,7 @@ def train(args): ): # do not mean over batch dimension for snr weight or scale v-pred loss loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) @@ -738,7 +738,7 @@ def train(args): loss = loss.mean() # mean over batch dimension else: loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) accelerator.backward(loss) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 24080afb..d0051d18 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -512,7 +512,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -534,7 +534,7 @@ def train(args): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) loss = loss.mean([1, 2, 3]) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 2946c97d..66214f5d 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -463,7 +463,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -485,7 +485,7 @@ def train(args): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) loss = loss.mean([1, 2, 3]) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 2d446523..5e10654b 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -406,7 +406,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -426,7 +426,9 @@ def train(args): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler + ) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_controlnet.py b/train_controlnet.py index 8c7882c8..da7a08d6 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -464,8 +464,8 @@ def train(args): ) # Sample a random timestep for each image - timesteps, huber_c = train_util.get_timesteps_and_huber_c( - args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device + timesteps = train_util.get_timesteps( + 0, noise_scheduler.config.num_train_timesteps, b_size, latents.device ) # Add noise to the latents according to the noise magnitude at each timestep @@ -499,7 +499,7 @@ def train(args): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) loss = loss.mean([1, 2, 3]) diff --git a/train_db.py b/train_db.py index 51e209f3..a185b31b 100644 --- a/train_db.py +++ b/train_db.py @@ -370,7 +370,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -385,7 +385,7 @@ def train(args): target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) diff --git a/train_network.py b/train_network.py index bbf381f9..c7d4f5dc 100644 --- a/train_network.py +++ b/train_network.py @@ -192,7 +192,7 @@ class NetworkTrainer: ): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # ensure the hidden state will require grad if args.gradient_checkpointing: @@ -244,7 +244,7 @@ class NetworkTrainer: network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) - return noise_pred, target, timesteps, huber_c, None + return noise_pred, target, timesteps, None def post_process_loss(self, loss, args, timesteps, noise_scheduler): if args.min_snr_gamma: @@ -806,6 +806,7 @@ class NetworkTrainer: "ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength, "ss_loss_type": args.loss_type, "ss_huber_schedule": args.huber_schedule, + "ss_huber_scale": args.huber_scale, "ss_huber_c": args.huber_c, "ss_fp8_base": bool(args.fp8_base), "ss_fp8_base_unet": bool(args.fp8_base_unet), @@ -1193,7 +1194,7 @@ class NetworkTrainer: text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target - noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( + noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( args, accelerator, noise_scheduler, @@ -1207,7 +1208,7 @@ class NetworkTrainer: ) loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if weighting is not None: loss = loss * weighting diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 5f4657eb..9e1e57c4 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -585,7 +585,7 @@ class TextualInversionTrainer: # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -602,7 +602,7 @@ class TextualInversionTrainer: target = noise loss = train_util.conditional_loss( - noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 52d525fc..94473360 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -461,7 +461,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -473,7 +473,9 @@ def train(args): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler + ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) From 740ec1d5265fa321659589ae6a75a4a9898ef8be Mon Sep 17 00:00:00 2001 From: recris Date: Thu, 28 Nov 2024 20:38:32 +0000 Subject: [PATCH 320/356] Fix issues found in review --- fine_tune.py | 2 +- library/train_util.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 70959a75..401a40f0 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -411,7 +411,7 @@ def train(args): loss = loss.mean() # mean over batch dimension else: loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler + args, noise_pred.float(), target.float(), timesteps, "mean", noise_scheduler ) accelerator.backward(loss) diff --git a/library/train_util.py b/library/train_util.py index c204ebd3..eaf6ec00 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5829,8 +5829,8 @@ def save_sd_model_on_train_end_common( def get_timesteps(min_timestep, max_timestep, b_size, device): - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) - timesteps = timesteps.long() + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") + timesteps = timesteps.long().to(device) return timesteps @@ -5875,8 +5875,8 @@ def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps result = torch.exp(-alpha * timesteps) * args.huber_scale elif args.huber_schedule == "snr": - if not hasattr(noise_scheduler, 'alphas_cumprod'): - raise NotImplementedError(f"Huber schedule 'snr' is not supported with the current model.") + if not hasattr(noise_scheduler, "alphas_cumprod"): + raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c From 575f583fd9cbaf7f7b644a31437ed9094810b99a Mon Sep 17 00:00:00 2001 From: minux302 Date: Fri, 29 Nov 2024 23:55:52 +0900 Subject: [PATCH 321/356] add README --- README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/README.md b/README.md index f9c85e3a..2b1ca3f8 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ Nov 14, 2024: - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) - [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) - [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) +- [FLUX.1 ControlNet training](#flux1-controlnet-training) - [FLUX.1 OFT training](#flux1-oft-training) - [Inference for FLUX.1 with LoRA model](#inference-for-flux1-with-lora-model) - [FLUX.1 fine-tuning](#flux1-fine-tuning) @@ -245,6 +246,22 @@ example: If you specify one of `train_double_block_indices` or `train_single_block_indices`, the other will be trained as usual. +### FLUX.1 ControlNet training +We have added a new training script for ControlNet training. The script is flux_train_control_net.py. See --help for options. + +Sample command is below. It will work with 80GB VRAM GPUs. +``` +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_control_net.py +--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors +--ae ae.safetensors --save_model_as safetensors --sdpa --persistent_data_loader_workers +--max_data_loader_n_workers 1 --seed 42 --gradient_checkpointing --mixed_precision bf16 +--optimizer_type adamw8bit --learning_rate 2e-5 +--highvram --max_train_epochs 1 --save_every_n_steps 1000 --dataset_config dataset.toml +--output_dir /path/to/output/dir --output_name flux-cn +--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 --deepspeed +``` + + ### FLUX.1 OFT training You can train OFT with almost the same options as LoRA, such as `--timestamp_sampling`. The following points are different. From be5860f8e266c5562f123fe9e0cb3febef615290 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sat, 30 Nov 2024 00:08:21 +0900 Subject: [PATCH 322/356] add schnell option to load_cn --- flux_train_control_net.py | 4 ++-- library/flux_utils.py | 14 ++++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index a17c811e..bb27c35e 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -259,14 +259,14 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - _, flux = flux_utils.load_flow_model( + is_schnell, flux = flux_utils.load_flow_model( args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors ) flux.requires_grad_(False) flux.to(accelerator.device) # load controlnet - controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, accelerator.device, args.disable_mmap_load_safetensors) + controlnet = flux_utils.load_controlnet(args.controlnet, is_schnell, torch.float32, accelerator.device, args.disable_mmap_load_safetensors) controlnet.train() if args.gradient_checkpointing: diff --git a/library/flux_utils.py b/library/flux_utils.py index f2759c37..8be1d63e 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -1,14 +1,14 @@ -from dataclasses import replace import json import os +from dataclasses import replace from typing import List, Optional, Tuple, Union + import einops import torch - -from safetensors.torch import load_file -from safetensors import safe_open from accelerate import init_empty_weights -from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config +from safetensors import safe_open +from safetensors.torch import load_file +from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel from library.utils import setup_logging @@ -154,11 +154,9 @@ def load_ae( def load_controlnet( - ckpt_path: Optional[str], dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False + ckpt_path: Optional[str], is_schnell: bool, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False ): logger.info("Building ControlNet") - # is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) - is_schnell = False name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL with torch.device(device): controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype) From f40632bac6704886a7640c327d64820f8f017df8 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sat, 30 Nov 2024 00:15:47 +0900 Subject: [PATCH 323/356] rm abundant arg --- flux_train_network.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 31433536..fa3810e3 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -6,12 +6,21 @@ from typing import Any, Optional import torch from accelerate import Accelerator -from library.device_utils import init_ipex, clean_memory_on_device + +from library.device_utils import clean_memory_on_device, init_ipex init_ipex() -from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util import train_network +from library import ( + flux_models, + flux_train_utils, + flux_utils, + sd3_train_utils, + strategy_base, + strategy_flux, + train_util, +) from library.utils import setup_logging setup_logging() @@ -125,7 +134,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model, controlnet + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model def get_tokenize_strategy(self, args): _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) From 928b9393daac252d0b6c4c9dd277d549b3dad8e9 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 20 Nov 2024 11:15:30 -0500 Subject: [PATCH 324/356] Allow unknown schedule-free optimizers to continue to module loader --- library/train_util.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 25cf7640..74050880 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4600,7 +4600,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args): def get_optimizer(args, trainable_params): # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" - + optimizer_type = args.optimizer_type if args.use_8bit_adam: assert ( @@ -4874,6 +4874,7 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type.endswith("schedulefree".lower()): + should_train_optimizer = True try: import schedulefree as sf except ImportError: @@ -4885,10 +4886,10 @@ def get_optimizer(args, trainable_params): optimizer_class = sf.SGDScheduleFree logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}") else: - raise ValueError(f"Unknown optimizer type: {optimizer_type}") - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - # make optimizer as train mode: we don't need to call train again, because eval will not be called in training loop - optimizer.train() + optimizer_class = None + + if optimizer_class is not None: + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) if optimizer is None: # 任意のoptimizerを使う @@ -4990,6 +4991,10 @@ def get_optimizer(args, trainable_params): optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) + if hasattr(optimizer, 'train') and callable(optimizer.train): + # make optimizer as train mode: we don't need to call train again, because eval will not be called in training loop + optimizer.train() + return optimizer_name, optimizer_args, optimizer From 87f5224e2d19254748158939cbca75802fc024f2 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 20 Nov 2024 11:57:15 -0500 Subject: [PATCH 325/356] Support d*lr for ProdigyPlus optimizer --- train_network.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index bbf381f9..65962bd7 100644 --- a/train_network.py +++ b/train_network.py @@ -61,6 +61,7 @@ class NetworkTrainer: avr_loss, lr_scheduler, lr_descriptions, + optimizer=None, keys_scaled=None, mean_norm=None, maximum_norm=None, @@ -93,6 +94,30 @@ class NetworkTrainer: logs[f"lr/d*lr/{lr_desc}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) + if ( + args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None + ): # tracking d*lr value of unet. + logs["lr/d*lr"] = ( + optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] + ) + else: + idx = 0 + if not args.network_train_unet_only: + logs["lr/textencoder"] = float(lrs[0]) + idx = 1 + + for i in range(idx, len(lrs)): + logs[f"lr/group{i}"] = float(lrs[i]) + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): + logs[f"lr/d*lr/group{i}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] + ) + if ( + args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None + ): + logs[f"lr/d*lr/group{i}"] = ( + optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] + ) return logs @@ -1279,7 +1304,7 @@ class NetworkTrainer: if len(accelerator.trackers) > 0: logs = self.generate_step_logs( - args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm + args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm ) accelerator.log(logs, step=global_step) From 6593cfbec14c0be70407b5d6d85d569ecf8160f1 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 21 Nov 2024 14:41:37 -0500 Subject: [PATCH 326/356] Fix d * lr step log --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 65962bd7..c236a2c9 100644 --- a/train_network.py +++ b/train_network.py @@ -116,7 +116,7 @@ class NetworkTrainer: args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None ): logs[f"lr/d*lr/group{i}"] = ( - optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] + optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] ) return logs From c7cadbc8c73b48eaacbfb44b18121d20df373e19 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 29 Nov 2024 15:52:03 -0500 Subject: [PATCH 327/356] Add pytest testing --- .github/workflows/tests.yml | 54 +++++++++++++ library/train_util.py | 4 +- pytest.ini | 7 ++ tests/test_optimizer.py | 153 ++++++++++++++++++++++++++++++++++++ 4 files changed, 216 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/tests.yml create mode 100644 pytest.ini create mode 100644 tests/test_optimizer.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..50b08243 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,54 @@ + +name: Python package + +on: [push] + +jobs: + build: + + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: ["3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.x' + - name: Install dependencies + run: python -m pip install --upgrade pip setuptools wheel + + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.x' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.x' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + - name: Test with pytest + run: | + pip install pytest pytest-cov + pytest --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html + + - name: Upload pytest test results + uses: actions/upload-artifact@v4 + with: + name: pytest-results-${{ matrix.python-version }} + path: junit/test-results-${{ matrix.python-version }}.xml + # Use always() to always run this step to publish test results when there are test failures + if: ${{ always() }} diff --git a/library/train_util.py b/library/train_util.py index 25cf7640..823cd366 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -21,7 +21,7 @@ from typing import ( Optional, Sequence, Tuple, - Union, + Union ) from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState import glob @@ -4598,7 +4598,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args): accelerator.load_state(dirname) -def get_optimizer(args, trainable_params): +def get_optimizer(args, trainable_params) -> tuple[str, str, object]: # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" optimizer_type = args.optimizer_type diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..63e03efc --- /dev/null +++ b/pytest.ini @@ -0,0 +1,7 @@ +[pytest] +minversion = 6.0 +testpaths = + tests +filterwarnings = + ignore::DeprecationWarning + ignore::UserWarning diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 00000000..f6ade91a --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,153 @@ +from unittest.mock import patch +from library.train_util import get_optimizer +from train_network import setup_parser +import torch +from torch.nn import Parameter + +# Optimizer libraries +import bitsandbytes as bnb +from lion_pytorch import lion_pytorch +import schedulefree + +import dadaptation +import dadaptation.experimental as dadapt_experimental + +import prodigyopt +import schedulefree as sf +import transformers + + +def test_default_get_optimizer(): + with patch("sys.argv", [""]): + parser = setup_parser() + args = parser.parse_args() + params_t = torch.tensor([1.5, 1.5]) + + param = Parameter(params_t) + optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param]) + assert optimizer_name == "torch.optim.adamw.AdamW" + assert optimizer_args == "" + assert isinstance(optimizer, torch.optim.AdamW) + + +def test_get_schedulefree_optimizer(): + with patch("sys.argv", ["", "--optimizer_type", "AdamWScheduleFree"]): + parser = setup_parser() + args = parser.parse_args() + params_t = torch.tensor([1.5, 1.5]) + + param = Parameter(params_t) + optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param]) + assert optimizer_name == "schedulefree.adamw_schedulefree.AdamWScheduleFree" + assert optimizer_args == "" + assert isinstance(optimizer, schedulefree.adamw_schedulefree.AdamWScheduleFree) + + +def test_all_supported_optimizers(): + optimizers = [ + { + "name": "bitsandbytes.optim.adamw.AdamW8bit", + "alias": "AdamW8bit", + "instance": bnb.optim.AdamW8bit, + }, + { + "name": "lion_pytorch.lion_pytorch.Lion", + "alias": "Lion", + "instance": lion_pytorch.Lion, + }, + { + "name": "torch.optim.adamw.AdamW", + "alias": "AdamW", + "instance": torch.optim.AdamW, + }, + { + "name": "bitsandbytes.optim.lion.Lion8bit", + "alias": "Lion8bit", + "instance": bnb.optim.Lion8bit, + }, + { + "name": "bitsandbytes.optim.adamw.PagedAdamW8bit", + "alias": "PagedAdamW8bit", + "instance": bnb.optim.PagedAdamW8bit, + }, + { + "name": "bitsandbytes.optim.lion.PagedLion8bit", + "alias": "PagedLion8bit", + "instance": bnb.optim.PagedLion8bit, + }, + { + "name": "bitsandbytes.optim.adamw.PagedAdamW", + "alias": "PagedAdamW", + "instance": bnb.optim.PagedAdamW, + }, + { + "name": "bitsandbytes.optim.adamw.PagedAdamW32bit", + "alias": "PagedAdamW32bit", + "instance": bnb.optim.PagedAdamW32bit, + }, + {"name": "torch.optim.sgd.SGD", "alias": "SGD", "instance": torch.optim.SGD}, + { + "name": "dadaptation.experimental.dadapt_adam_preprint.DAdaptAdamPreprint", + "alias": "DAdaptAdamPreprint", + "instance": dadapt_experimental.DAdaptAdamPreprint, + }, + { + "name": "dadaptation.dadapt_adagrad.DAdaptAdaGrad", + "alias": "DAdaptAdaGrad", + "instance": dadaptation.DAdaptAdaGrad, + }, + { + "name": "dadaptation.dadapt_adan.DAdaptAdan", + "alias": "DAdaptAdan", + "instance": dadaptation.DAdaptAdan, + }, + { + "name": "dadaptation.experimental.dadapt_adan_ip.DAdaptAdanIP", + "alias": "DAdaptAdanIP", + "instance": dadapt_experimental.DAdaptAdanIP, + }, + { + "name": "dadaptation.dadapt_lion.DAdaptLion", + "alias": "DAdaptLion", + "instance": dadaptation.DAdaptLion, + }, + { + "name": "dadaptation.dadapt_sgd.DAdaptSGD", + "alias": "DAdaptSGD", + "instance": dadaptation.DAdaptSGD, + }, + { + "name": "prodigyopt.prodigy.Prodigy", + "alias": "Prodigy", + "instance": prodigyopt.Prodigy, + }, + { + "name": "transformers.optimization.Adafactor", + "alias": "Adafactor", + "instance": transformers.optimization.Adafactor, + }, + { + "name": "schedulefree.adamw_schedulefree.AdamWScheduleFree", + "alias": "AdamWScheduleFree", + "instance": sf.AdamWScheduleFree, + }, + { + "name": "schedulefree.sgd_schedulefree.SGDScheduleFree", + "alias": "SGDScheduleFree", + "instance": sf.SGDScheduleFree, + }, + ] + + for opt in optimizers: + with patch("sys.argv", ["", "--optimizer_type", opt.get("alias")]): + parser = setup_parser() + args = parser.parse_args() + params_t = torch.tensor([1.5, 1.5]) + + param = Parameter(params_t) + optimizer_name, _, optimizer = get_optimizer(args, [param]) + assert optimizer_name == opt.get("name") + + instance = opt.get("instance") + assert instance is not None + assert isinstance(optimizer, instance) From 2dd063a679effae2538c474fece1e7aacad0c9c5 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 29 Nov 2024 15:57:31 -0500 Subject: [PATCH 328/356] add torch torchvision accelerate versions --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 50b08243..96ab612d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -40,6 +40,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt + pip install torch==2.4.0 torchvision==0.19.0 accelerate==0.33.0 - name: Test with pytest run: | pip install pytest pytest-cov From e59e276fb948a1dc8a64672d8fd6d3a7eb166c80 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 29 Nov 2024 16:03:29 -0500 Subject: [PATCH 329/356] Add dadaptation --- .github/workflows/tests.yml | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 96ab612d..433c326b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,7 +10,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: ["3.10", "3.11"] + python-version: ["3.10"] steps: - uses: actions/checkout@v4 @@ -26,30 +26,14 @@ jobs: uses: actions/setup-python@v5 with: python-version: '3.x' + cache: 'pip' # caching pip dependencies - name: Install dependencies run: | python -m pip install --upgrade pip pip install -r requirements.txt - - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.x' - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements.txt - pip install torch==2.4.0 torchvision==0.19.0 accelerate==0.33.0 + pip install dadaptation==3.2 torch==2.4.0 torchvision==0.19.0 accelerate==0.33.0 - name: Test with pytest run: | - pip install pytest pytest-cov - pytest --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html + pip install pytest + pytest - - name: Upload pytest test results - uses: actions/upload-artifact@v4 - with: - name: pytest-results-${{ matrix.python-version }} - path: junit/test-results-${{ matrix.python-version }}.xml - # Use always() to always run this step to publish test results when there are test failures - if: ${{ always() }} From dd3b846b54814b605bd33ae08ed480ea5075483b Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 29 Nov 2024 16:18:05 -0500 Subject: [PATCH 330/356] Install pytorch first to pin version --- .github/workflows/tests.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 433c326b..9ae67b0e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,6 +18,7 @@ jobs: uses: actions/setup-python@v5 with: python-version: '3.x' + - name: Install dependencies run: python -m pip install --upgrade pip setuptools wheel @@ -27,11 +28,13 @@ jobs: with: python-version: '3.x' cache: 'pip' # caching pip dependencies + - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt pip install dadaptation==3.2 torch==2.4.0 torchvision==0.19.0 accelerate==0.33.0 + pip install -r requirements.txt + - name: Test with pytest run: | pip install pytest From 89825d6898ba6629b18cc8c1f9fbd93a730ff36e Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 29 Nov 2024 16:27:13 -0500 Subject: [PATCH 331/356] Run typos workflows once where appropriate --- .github/workflows/typos.yml | 6 ++++-- pytest.ini | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index 0149dcdd..667146a7 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -1,9 +1,11 @@ --- -# yamllint disable rule:line-length name: Typos -on: # yamllint disable-line rule:truthy +on: push: + branches: + - main + - dev pull_request: types: - opened diff --git a/pytest.ini b/pytest.ini index 63e03efc..484d3aef 100644 --- a/pytest.ini +++ b/pytest.ini @@ -5,3 +5,4 @@ testpaths = filterwarnings = ignore::DeprecationWarning ignore::UserWarning + ignore::FutureWarning From 4f7f248071c93f539c12c8a35380b6d983bfff4c Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 29 Nov 2024 16:28:51 -0500 Subject: [PATCH 332/356] Bump typos action --- .github/workflows/typos.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index 667146a7..87ebdf89 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -20,4 +20,4 @@ jobs: - uses: actions/checkout@v4 - name: typos-action - uses: crate-ci/typos@v1.24.3 + uses: crate-ci/typos@v1.28.1 From 9c885e549dbb5535b37f2a3220b5a8f53ad4d211 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 30 Nov 2024 18:25:50 +0900 Subject: [PATCH 333/356] fix: improve pos_embed handling for oversized images and update resolution_area_to_latent_size, when sample image size > train image size --- library/sd3_models.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/library/sd3_models.py b/library/sd3_models.py index 8b90205d..2f3c82ee 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -1017,22 +1017,35 @@ class MMDiT(nn.Module): patched_size = patched_size_ break if patched_size is None: - raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.") + # raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.") + # use largest latent size + patched_size = self.resolution_area_to_latent_size[-1][1] pos_embed = self.resolution_pos_embeds[patched_size] - pos_embed_size = round(math.sqrt(pos_embed.shape[1])) + pos_embed_size = round(math.sqrt(pos_embed.shape[1])) # max size, patched_size * POS_EMBED_MAX_RATIO if h > pos_embed_size or w > pos_embed_size: # # fallback to normal pos_embed # return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop) # extend pos_embed size logger.warning( - f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide." + f"Add new pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide." ) - pos_embed_size = max(h, w) - pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size) + patched_size = max(h, w) + grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO) + pos_embed_size = grid_size + pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size) pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) self.resolution_pos_embeds[patched_size] = pos_embed - logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}") + logger.info(f"Added pos_embed for size {patched_size}x{patched_size}") + + # print(torch.allclose(pos_embed.to(torch.float32).cpu(), self.pos_embed.to(torch.float32).cpu(), atol=5e-2)) + # diff = pos_embed.to(torch.float32).cpu() - self.pos_embed.to(torch.float32).cpu() + # print(diff.abs().max(), diff.abs().mean()) + + # insert to resolution_area_to_latent_size, by adding and sorting + area = pos_embed_size**2 + self.resolution_area_to_latent_size.append((area, patched_size)) + self.resolution_area_to_latent_size = sorted(self.resolution_area_to_latent_size) if not random_crop: top = (pos_embed_size - h) // 2 From 7b61e9eb58e0a004b451e8f06c9f90b861f81b45 Mon Sep 17 00:00:00 2001 From: recris Date: Sat, 30 Nov 2024 11:36:40 +0000 Subject: [PATCH 334/356] Fix issues found in review (pt 2) --- library/train_util.py | 2 +- sd3_train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index eaf6ec00..d5e72323 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5875,7 +5875,7 @@ def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps result = torch.exp(-alpha * timesteps) * args.huber_scale elif args.huber_schedule == "snr": - if not hasattr(noise_scheduler, "alphas_cumprod"): + if noise_scheduler is None or not hasattr(noise_scheduler, "alphas_cumprod"): raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 diff --git a/sd3_train.py b/sd3_train.py index cf2bdf93..909c5ead 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -845,7 +845,7 @@ def train(args): # ) # calculate loss loss = train_util.conditional_loss( - args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler + args, model_pred.float(), target.float(), timesteps, "none", None ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) From 14f642f88be888ce1a4157b550186347c159ca42 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Dec 2024 13:30:35 +0900 Subject: [PATCH 335/356] fix: huber_schedule exponential not working on sd3_train.py --- library/train_util.py | 2 +- sd3_train.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index d5e72323..eaf6ec00 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5875,7 +5875,7 @@ def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps result = torch.exp(-alpha * timesteps) * args.huber_scale elif args.huber_schedule == "snr": - if noise_scheduler is None or not hasattr(noise_scheduler, "alphas_cumprod"): + if not hasattr(noise_scheduler, "alphas_cumprod"): raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 diff --git a/sd3_train.py b/sd3_train.py index 909c5ead..73a68aa6 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -675,8 +675,8 @@ def train(args): progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 - # noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) - # noise_scheduler_copy = copy.deepcopy(noise_scheduler) + # only used to get timesteps, etc. TODO manage timesteps etc. separately + dummy_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) if accelerator.is_main_process: init_kwargs = {} @@ -844,9 +844,7 @@ def train(args): # 1, # ) # calculate loss - loss = train_util.conditional_loss( - args, model_pred.float(), target.float(), timesteps, "none", None - ) + loss = train_util.conditional_loss(args, model_pred.float(), target.float(), timesteps, "none", dummy_scheduler) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) From 0fe6320f09a61859c3faa134affb810cb42b62cd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Dec 2024 14:13:37 +0900 Subject: [PATCH 336/356] fix flux_train.py is not working --- flux_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux_train.py b/flux_train.py index f6e43b27..cfe14885 100644 --- a/flux_train.py +++ b/flux_train.py @@ -667,7 +667,7 @@ def train(args): # calculate loss loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler + args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if weighting is not None: loss = loss * weighting From cc11989755d0dd61f10eeec85983c751fd7ebb47 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Dec 2024 21:20:28 +0900 Subject: [PATCH 337/356] fix: refactor huber-loss calculation in multiple training scripts --- fine_tune.py | 13 ++++--------- flux_train.py | 5 ++--- library/train_util.py | 21 +++++++++++---------- sd3_train.py | 3 ++- sdxl_train.py | 13 ++++--------- sdxl_train_control_net.py | 9 +++------ sdxl_train_control_net_lllite.py | 9 +++------ sdxl_train_control_net_lllite_old.py | 10 ++++++---- train_controlnet.py | 11 +++++------ train_db.py | 9 +++------ train_network.py | 5 ++--- train_textual_inversion.py | 5 ++--- train_textual_inversion_XTI.py | 9 +++++---- 13 files changed, 52 insertions(+), 70 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 401a40f0..17608706 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -380,9 +380,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -394,11 +392,10 @@ def train(args): else: target = noise + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: @@ -410,9 +407,7 @@ def train(args): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "mean", noise_scheduler - ) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: diff --git a/flux_train.py b/flux_train.py index cfe14885..fced3bef 100644 --- a/flux_train.py +++ b/flux_train.py @@ -666,9 +666,8 @@ def train(args): target = noise - latents # calculate loss - loss = train_util.conditional_loss( - args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) if weighting is not None: loss = loss * weighting if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): diff --git a/library/train_util.py b/library/train_util.py index eaf6ec00..fe74ddc7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5869,7 +5869,10 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): return noise, noisy_latents, timesteps -def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor: +def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: + if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): + return None + b_size = timesteps.shape[0] if args.huber_schedule == "exponential": alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps @@ -5890,22 +5893,20 @@ def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch def conditional_loss( - args, model_pred: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor, reduction: str, noise_scheduler + model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None ): - if args.loss_type == "l2": + if loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) - elif args.loss_type == "l1": + elif loss_type == "l1": loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) - elif args.loss_type == "huber": - huber_c = get_huber_threshold(args, timesteps, noise_scheduler) + elif loss_type == "huber": huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) elif reduction == "sum": loss = torch.sum(loss) - elif args.loss_type == "smooth_l1": - huber_c = get_huber_threshold(args, timesteps, noise_scheduler) + elif loss_type == "smooth_l1": huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": @@ -5913,7 +5914,7 @@ def conditional_loss( elif reduction == "sum": loss = torch.sum(loss) else: - raise NotImplementedError(f"Unsupported Loss Type: {args.loss_type}") + raise NotImplementedError(f"Unsupported Loss Type: {loss_type}") return loss @@ -5923,7 +5924,7 @@ def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True): names.append("unet") names.append("text_encoder1") names.append("text_encoder2") - names.append("text_encoder3") # SD3 + names.append("text_encoder3") # SD3 append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names) diff --git a/sd3_train.py b/sd3_train.py index 73a68aa6..120455e7 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -844,7 +844,8 @@ def train(args): # 1, # ) # calculate loss - loss = train_util.conditional_loss(args, model_pred.float(), target.float(), timesteps, "none", dummy_scheduler) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, dummy_scheduler) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) diff --git a/sdxl_train.py b/sdxl_train.py index 1bc27ec6..b9d52924 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -695,9 +695,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -711,6 +709,7 @@ def train(args): else: target = noise + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) if ( args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred @@ -719,9 +718,7 @@ def train(args): or args.masked_loss ): # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -737,9 +734,7 @@ def train(args): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c) accelerator.backward(loss) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index d0051d18..01387409 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -512,9 +512,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) @@ -533,9 +531,8 @@ def train(args): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 66214f5d..365059b7 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -463,9 +463,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -484,9 +482,8 @@ def train(args): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 5e10654b..5b372bef 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -12,6 +12,7 @@ from tqdm import tqdm import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -324,7 +325,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -426,9 +429,8 @@ def train(args): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_controlnet.py b/train_controlnet.py index da7a08d6..177d2b11 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -307,10 +307,12 @@ def train(args): if args.fused_backward_pass: import library.adafactor_fused + library.adafactor_fused.patch_adafactor_fused(optimizer) for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: + def __grad_hook(tensor: torch.Tensor, param_group=param_group): if accelerator.sync_gradients and args.max_grad_norm != 0.0: accelerator.clip_grad_norm_(tensor, args.max_grad_norm) @@ -464,9 +466,7 @@ def train(args): ) # Sample a random timestep for each image - timesteps = train_util.get_timesteps( - 0, noise_scheduler.config.num_train_timesteps, b_size, latents.device - ) + timesteps = train_util.get_timesteps(0, noise_scheduler.config.num_train_timesteps, b_size, latents.device) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -498,9 +498,8 @@ def train(args): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_db.py b/train_db.py index a185b31b..ad21f8d1 100644 --- a/train_db.py +++ b/train_db.py @@ -370,9 +370,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -384,9 +382,8 @@ def train(args): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) diff --git a/train_network.py b/train_network.py index c7d4f5dc..0b420818 100644 --- a/train_network.py +++ b/train_network.py @@ -1207,9 +1207,8 @@ class NetworkTrainer: train_unet, ) - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if weighting is not None: loss = loss * weighting if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 9e1e57c4..65da4859 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -601,9 +601,8 @@ class TextualInversionTrainer: else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 94473360..2a2b4231 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -407,7 +407,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) # function for saving/removing @@ -473,9 +475,8 @@ def train(args): else: target = noise - loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler - ) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) From 14760407871c7eaa26210c7db71ce2740a817c4c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Dec 2024 21:26:39 +0900 Subject: [PATCH 338/356] fix: update help text for huber loss parameters in train_util.py --- library/train_util.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index fe74ddc7..a40983a6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3905,14 +3905,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--huber_c", type=float, default=0.1, - help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1" + " / Huber損失の減衰パラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", ) parser.add_argument( "--huber_scale", type=float, default=1.0, - help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0" + " / Huber損失のスケールパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは1.0", ) parser.add_argument( From 34e7f509c41491f9a08c16c8ead2adf5cb210ec1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Dec 2024 21:36:24 +0900 Subject: [PATCH 339/356] docs: update README for huber loss --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index f9c85e3a..89a96827 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,11 @@ The command to install PyTorch is as follows: ### Recent Updates +1 Dec, 2024: + +- Pseudo Huber loss is now available for FLUX.1 and SD3.5 training. See [#1808](https://github.com/kohya-ss/sd-scripts/pull/1808) for details. Thanks to recris! + - Specify `--loss_type huber` or `--loss_type smooth_l1` to use it. `--huber_c` and `--huber_scale` are also available. + Nov 14, 2024: - Improved the implementation of block swap and made it available for both FLUX.1 and SD3 LoRA training. See [FLUX.1 LoRA training](#flux1-lora-training) etc. for how to use the new options. Training is possible with about 8-10GB of VRAM. From 1dc873d9b463d50e27ae8572c28a473ce9a1254f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Dec 2024 22:00:44 +0900 Subject: [PATCH 340/356] update README and clean up code for schedulefree optimizer --- README.md | 4 +++- library/train_util.py | 7 +++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 89a96827..8db5c4d4 100644 --- a/README.md +++ b/README.md @@ -16,9 +16,11 @@ The command to install PyTorch is as follows: 1 Dec, 2024: -- Pseudo Huber loss is now available for FLUX.1 and SD3.5 training. See [#1808](https://github.com/kohya-ss/sd-scripts/pull/1808) for details. Thanks to recris! +- Pseudo Huber loss is now available for FLUX.1 and SD3.5 training. See PR [#1808](https://github.com/kohya-ss/sd-scripts/pull/1808) for details. Thanks to recris! - Specify `--loss_type huber` or `--loss_type smooth_l1` to use it. `--huber_c` and `--huber_scale` are also available. +- [Prodigy + ScheduleFree](https://github.com/LoganBooker/prodigy-plus-schedule-free) is supported. See PR [#1811](https://github.com/kohya-ss/sd-scripts/pull/1811) for details. Thanks to rockerBOO! + Nov 14, 2024: - Improved the implementation of block swap and made it available for both FLUX.1 and SD3 LoRA training. See [FLUX.1 LoRA training](#flux1-lora-training) etc. for how to use the new options. Training is possible with about 8-10GB of VRAM. diff --git a/library/train_util.py b/library/train_util.py index 289ab823..6cfd14d5 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4609,7 +4609,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args): def get_optimizer(args, trainable_params): # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" - + optimizer_type = args.optimizer_type if args.use_8bit_adam: assert ( @@ -4883,7 +4883,6 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type.endswith("schedulefree".lower()): - should_train_optimizer = True try: import schedulefree as sf except ImportError: @@ -5000,8 +4999,8 @@ def get_optimizer(args, trainable_params): optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) - if hasattr(optimizer, 'train') and callable(optimizer.train): - # make optimizer as train mode: we don't need to call train again, because eval will not be called in training loop + if hasattr(optimizer, "train") and callable(optimizer.train): + # make optimizer as train mode before training for schedulefree optimizer. the optimizer will be in eval mode in sampling and saving. optimizer.train() return optimizer_name, optimizer_args, optimizer From e369b9a252b90d1f57ea20dd6f5d05ec0c287ae1 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Mon, 2 Dec 2024 23:38:54 +0900 Subject: [PATCH 341/356] docs: update README with FLUX.1 ControlNet training details and improve argument help text --- README.md | 10 +++++++++- library/flux_train_utils.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 45e3cb7a..6a5cdd34 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,15 @@ The command to install PyTorch is as follows: ### Recent Updates -1 Dec, 2024: +Dec 2, 2024: + +- FLUX.1 ControlNet training is supported. PR [#1813](https://github.com/kohya-ss/sd-scripts/pull/1813). Thanks to minux302! See PR and [here](#flux1-controlnet-training) for details. + - Not fully tested. Feedback is welcome. + - 80GB VRAM is required for 1024x1024 resolution, and 48GB VRAM is required for 512x512 resolution. + - Currently, it only works in Linux environment (or Windows WSL2) because DeepSpeed is required. + - Multi-GPU training is not tested. + +Dec 1, 2024: - Pseudo Huber loss is now available for FLUX.1 and SD3.5 training. See PR [#1808](https://github.com/kohya-ss/sd-scripts/pull/1808) for details. Thanks to recris! - Specify `--loss_type huber` or `--loss_type smooth_l1` to use it. `--huber_c` and `--huber_scale` are also available. diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 5e25c7fe..de2e2b48 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -567,7 +567,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): "--controlnet", type=str, default=None, - help="path to controlnet (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)" + help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)" ) parser.add_argument( "--t5xxl_max_token_length", From 5ab00f9b49b5a3958bb0267fdb9236a96d503dbd Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 2 Dec 2024 13:39:51 -0500 Subject: [PATCH 342/356] Update workflow tests with cleanup and documentation --- .github/workflows/tests.yml | 46 +++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9ae67b0e..5a790d57 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,42 +1,44 @@ +name: Test with pytet -name: Python package - -on: [push] +on: + push: + branches: + - main + - dev + - sd3 + pull_request: + branches: + - main + - dev + - sd3 jobs: build: - runs-on: ${{ matrix.os }} strategy: matrix: os: [ubuntu-latest] - python-version: ["3.10"] + python-version: ["3.10"] # Python versions to test + pytorch-version: ["2.4.0"] # PyTorch versions to test steps: - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 + - uses: actions/setup-python@v5 with: - python-version: '3.x' + python-version: ${{ matrix.python-version }} + cache: 'pip' - - name: Install dependencies - run: python -m pip install --upgrade pip setuptools wheel - - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.x' - cache: 'pip' # caching pip dependencies + - name: Install and update pip, setuptools, wheel + run: | + # Setuptools, wheel for compiling some packages + python -m pip install --upgrade pip setuptools wheel - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install dadaptation==3.2 torch==2.4.0 torchvision==0.19.0 accelerate==0.33.0 + # Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch) + pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision==0.19.0 pytest==8.3.4 pip install -r requirements.txt - name: Test with pytest - run: | - pip install pytest - pytest + run: pytest # See pytest.ini for configuration From 63738ecb0758a02555392d2c283a83bba1c6f98e Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 2 Dec 2024 13:48:30 -0500 Subject: [PATCH 343/356] Add tests documentation --- tests/README.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 tests/README.md diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 00000000..19eeab0e --- /dev/null +++ b/tests/README.md @@ -0,0 +1,32 @@ +# Tests + +## Install + +``` +pip install pytest +``` + +## Usage + +``` +pytest +``` + +## Contribution + +Pytest is configured to run tests in this directory. It might be a good idea to add tests closer in the code, as well as doctests. + +Tests are functions starting with `test_` and files with the pattern `test_*.py`. + +``` +def test_x(): + assert 1 == 2, "Invalid test response" +``` + +## Resources + +- https://circleci.com/blog/testing-pytorch-model-with-pytest/ +- https://pytorch.org/docs/stable/testing.html +- https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests +- https://github.com/huggingface/pytorch-image-models/tree/main/tests +- https://github.com/pytorch/pytorch/tree/main/test From 2610e96e9e3d0605d5a16615efa26ae8935ed3aa Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 2 Dec 2024 13:49:58 -0500 Subject: [PATCH 344/356] Pytest --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5a790d57..672a657b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,4 +1,4 @@ -name: Test with pytet +name: Test with pytest on: push: From 3e5d89c76c287872e20c4a967d36b51384285be8 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 2 Dec 2024 13:51:57 -0500 Subject: [PATCH 345/356] Add more resources --- tests/README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/README.md b/tests/README.md index 19eeab0e..9836da8b 100644 --- a/tests/README.md +++ b/tests/README.md @@ -25,8 +25,17 @@ def test_x(): ## Resources +### pytest + +- https://docs.pytest.org/en/stable/index.html +- https://docs.pytest.org/en/stable/how-to/assert.html +- https://docs.pytest.org/en/stable/how-to/doctest.html + +### PyTorch testing + - https://circleci.com/blog/testing-pytorch-model-with-pytest/ - https://pytorch.org/docs/stable/testing.html - https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests - https://github.com/huggingface/pytorch-image-models/tree/main/tests - https://github.com/pytorch/pytorch/tree/main/test + From 8b36d907d8635dca64224574b5cb15013e00809d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 3 Dec 2024 08:43:26 +0900 Subject: [PATCH 346/356] feat: support block_to_swap for FLUX.1 ControlNet training --- README.md | 13 +++++++++++ flux_train_control_net.py | 46 +++++++++++++++++++++++++++------------ 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 6a5cdd34..f0272519 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,11 @@ The command to install PyTorch is as follows: ### Recent Updates + +Dec 3, 2024: + +-`--blocks_to_swap` now works in FLUX.1 ControlNet training. Sample commands for 24GB VRAM and 16GB VRAM are added [here](#flux1-controlnet-training). + Dec 2, 2024: - FLUX.1 ControlNet training is supported. PR [#1813](https://github.com/kohya-ss/sd-scripts/pull/1813). Thanks to minux302! See PR and [here](#flux1-controlnet-training) for details. @@ -276,6 +281,14 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_tr --timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 --deepspeed ``` +For 24GB VRAM GPUs, you can train with 16 blocks swapped and caching latents and text encoder outputs with the batch size of 1. Remove `--deepspeed` . Sample command is below. Not fully tested. +``` + --blocks_to_swap 16 --cache_latents_to_disk --cache_text_encoder_outputs_to_disk +``` + +The training can be done with 16GB VRAM GPUs with around 30 blocks swapped. + +`--gradient_accumulation_steps` is also available. The default value is 1 (no accumulation), but according to the original PR, 8 is used. ### FLUX.1 OFT training diff --git a/flux_train_control_net.py b/flux_train_control_net.py index bb27c35e..5548fd99 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -119,9 +119,7 @@ def train(args): "datasets": [ { "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( - args.train_data_dir, - args.conditioning_data_dir, - args.caption_extension + args.train_data_dir, args.conditioning_data_dir, args.caption_extension ) } ] @@ -263,13 +261,17 @@ def train(args): args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors ) flux.requires_grad_(False) - flux.to(accelerator.device) # load controlnet - controlnet = flux_utils.load_controlnet(args.controlnet, is_schnell, torch.float32, accelerator.device, args.disable_mmap_load_safetensors) + controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype + controlnet = flux_utils.load_controlnet( + args.controlnet, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors + ) controlnet.train() if args.gradient_checkpointing: + if not args.deepspeed: + flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) # block swap @@ -296,7 +298,11 @@ def train(args): # This idea is based on 2kpr's great work. Thank you! logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") flux.enable_block_swap(args.blocks_to_swap, accelerator.device) - controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device) + flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + # ControlNet only has two blocks, so we can keep it on GPU + # controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device) + else: + flux.to(accelerator.device) if not cache_latents: # load VAE here if not cached @@ -455,9 +461,7 @@ def train(args): else: # accelerator does some magic # if we doesn't swap blocks, we can move the model to device - controlnet = accelerator.prepare(controlnet, device_placement=[not is_swapping_blocks]) - if is_swapping_blocks: - accelerator.unwrap_model(controlnet).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + controlnet = accelerator.prepare(controlnet) # , device_placement=[not is_swapping_blocks]) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする @@ -564,11 +568,13 @@ def train(args): ) if is_swapping_blocks: - accelerator.unwrap_model(controlnet).prepare_block_swap_before_forward() + flux.prepare_block_swap_before_forward() # For --sample_at_first optimizer_eval_fn() - flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet) + flux_train_utils.sample_images( + accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet + ) optimizer_train_fn() if len(accelerator.trackers) > 0: # log empty object to commit the sample images to wandb @@ -629,7 +635,11 @@ def train(args): # pack latents and get img_ids packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 - img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device).to(weight_dtype) + img_ids = ( + flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width) + .to(device=accelerator.device) + .to(weight_dtype) + ) # get guidance: ensure args.guidance_scale is float guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device, dtype=weight_dtype) @@ -638,7 +648,7 @@ def train(args): l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds if not args.apply_t5_attn_mask: t5_attn_mask = None - + with accelerator.autocast(): block_samples, block_single_samples = controlnet( img=packed_noisy_model_input, @@ -715,7 +725,15 @@ def train(args): optimizer_eval_fn() flux_train_utils.sample_images( - accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet + accelerator, + args, + None, + global_step, + flux, + ae, + [clip_l, t5xxl], + sample_prompts_te_outputs, + controlnet=controlnet, ) # 指定ステップごとにモデルを保存 From 6bee18db4fbf62ebd2a1da88a5851c48f2e06c54 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 7 Dec 2024 15:12:27 +0900 Subject: [PATCH 347/356] fix: resolve model corruption issue with pos_embed when using --enable_scaled_pos_embed --- README.md | 2 ++ library/sd3_models.py | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f0272519..6162359d 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,8 @@ The command to install PyTorch is as follows: ### Recent Updates +Dec 7, 2024: +- Fixed an issue where the saved model would be corrupted (pos_embed would not be saved) when `--enable_scaled_pos_embed` was specified in `sd3_train.py`. Dec 3, 2024: diff --git a/library/sd3_models.py b/library/sd3_models.py index 2f3c82ee..e4a93186 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -870,8 +870,10 @@ class MMDiT(nn.Module): self.use_scaled_pos_embed = use_scaled_pos_embed if self.use_scaled_pos_embed: - # remove pos_embed to free up memory up to 0.4 GB - self.pos_embed = None + # # remove pos_embed to free up memory up to 0.4 GB -> this causes error because pos_embed is not saved + # self.pos_embed = None + # move pos_embed to CPU to free up memory up to 0.4 GB + self.pos_embed = self.pos_embed.cpu() # remove duplicates and sort latent sizes in ascending order latent_sizes = list(set(latent_sizes)) From abff4b0ec7bb37b338924e38392593f2bea2b8d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Sat, 7 Dec 2024 16:12:46 +0800 Subject: [PATCH 348/356] Unify controlnet parameters name and change scripts name. (#1821) * Update sd3_train.py * add freeze block lr * Update train_util.py * update * Revert "add freeze block lr" This reverts commit 8b1653548f8f219e5be2cde96f65a8813cf9ea1f. # Conflicts: # library/train_util.py # sd3_train.py * use same control net model path * use controlnet_model_name_or_path --- flux_train_control_net.py | 2 +- library/flux_train_utils.py | 2 +- sdxl_train_control_net.py | 8 ++++---- train_controlnet.py => train_control_net.py | 0 4 files changed, 6 insertions(+), 6 deletions(-) rename train_controlnet.py => train_control_net.py (100%) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 5548fd99..9d36a41d 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -265,7 +265,7 @@ def train(args): # load controlnet controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype controlnet = flux_utils.load_controlnet( - args.controlnet, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors + args.controlnet_model_name_or_path, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors ) controlnet.train() diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index de2e2b48..f7f06c5c 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -564,7 +564,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): ) parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") parser.add_argument( - "--controlnet", + "--controlnet_model_name_or_path", type=str, default=None, help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)" diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 01387409..ffbf03ca 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -184,12 +184,12 @@ def train(args): # make control net logger.info("make ControlNet") - if args.controlnet_model_path: + if args.controlnet_model_name_or_path: with init_empty_weights(): control_net = SdxlControlNet() - logger.info(f"load ControlNet from {args.controlnet_model_path}") - filename = args.controlnet_model_path + logger.info(f"load ControlNet from {args.controlnet_model_name_or_path}") + filename = args.controlnet_model_name_or_path if os.path.splitext(filename)[1] == ".safetensors": state_dict = load_file(filename) else: @@ -675,7 +675,7 @@ def setup_parser() -> argparse.ArgumentParser: sdxl_train_util.add_sdxl_training_arguments(parser) parser.add_argument( - "--controlnet_model_path", + "--controlnet_model_name_or_path", type=str, default=None, help="controlnet model name or path / controlnetのモデル名またはパス", diff --git a/train_controlnet.py b/train_control_net.py similarity index 100% rename from train_controlnet.py rename to train_control_net.py From e425996a5953f0479384e70b6490e751c2d00b1f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 7 Dec 2024 17:28:19 +0900 Subject: [PATCH 349/356] feat: unify ControlNet model name option and deprecate old training script --- README.md | 7 +++++++ train_controlnet.py | 23 +++++++++++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 train_controlnet.py diff --git a/README.md b/README.md index 6162359d..67836ddf 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,13 @@ The command to install PyTorch is as follows: ### Recent Updates Dec 7, 2024: + +- The option to specify the model name during ControlNet training was different in each script. It has been unified. Please specify `--controlnet_model_name_or_path`. PR [#1821](https://github.com/kohya-ss/sd-scripts/pull/1821) Thanks to sdbds! + + - Fixed an issue where the saved model would be corrupted (pos_embed would not be saved) when `--enable_scaled_pos_embed` was specified in `sd3_train.py`. Dec 3, 2024: diff --git a/train_controlnet.py b/train_controlnet.py new file mode 100644 index 00000000..365e35c8 --- /dev/null +++ b/train_controlnet.py @@ -0,0 +1,23 @@ +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +from library import train_util +from train_control_net import setup_parser, train + +if __name__ == "__main__": + logger.warning( + "The module 'train_controlnet.py' is deprecated. Please use 'train_control_net.py' instead" + " / 'train_controlnet.py'は非推奨です。代わりに'train_control_net.py'を使用してください。" + ) + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) From 3cb8cb2d4fd697a49135193ac0873204e0139e62 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 9 Dec 2024 15:20:04 -0500 Subject: [PATCH 350/356] Prevent git credentials from leaking into other actions --- .github/workflows/tests.yml | 4 ++++ .github/workflows/typos.yml | 3 +++ 2 files changed, 7 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 672a657b..2eddedc7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,6 +23,10 @@ jobs: steps: - uses: actions/checkout@v4 + with: + # https://woodruffw.github.io/zizmor/audits/#artipacked + persist-credentials: false + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index 87ebdf89..f53cda21 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -18,6 +18,9 @@ jobs: steps: - uses: actions/checkout@v4 + with: + # https://woodruffw.github.io/zizmor/audits/#artipacked + persist-credentials: false - name: typos-action uses: crate-ci/typos@v1.28.1 From 8e378cf03df645cef897a342559dc5fa7f66a35d Mon Sep 17 00:00:00 2001 From: nhamanasu Date: Wed, 11 Dec 2024 19:43:44 +0900 Subject: [PATCH 351/356] add RAdamScheduleFree support --- library/train_util.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index a35388fe..72b5b24d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4887,7 +4887,11 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]: import schedulefree as sf except ImportError: raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") - if optimizer_type == "AdamWScheduleFree".lower(): + + if optimizer_type == "RAdamScheduleFree".lower(): + optimizer_class = sf.RAdamScheduleFree + logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}") + elif optimizer_type == "AdamWScheduleFree".lower(): optimizer_class = sf.AdamWScheduleFree logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}") elif optimizer_type == "SGDScheduleFree".lower(): From e89653975ddf429cdf0c0fd268da0a5a3e8dba1f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 15 Dec 2024 19:39:47 +0900 Subject: [PATCH 352/356] update requirements.txt and README to include RAdamScheduleFree optimizer support --- README.md | 6 ++++++ requirements.txt | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 67836ddf..bfb22bcf 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,12 @@ The command to install PyTorch is as follows: ### Recent Updates +Dec 15, 2024: + +- RAdamScheduleFree optimizer is supported. PR [#1830](https://github.com/kohya-ss/sd-scripts/pull/1830) Thanks to nhamanasu! + - Update to `schedulefree==1.4` is required. Please update individually or with `pip install --use-pep517 --upgrade -r requirements.txt`. + - Available with `--optimizer_type=RAdamScheduleFree`. No need to specify warm up steps as well as learning rate scheduler. + Dec 7, 2024: - The option to specify the model name during ControlNet training was different in each script. It has been unified. Please specify `--controlnet_model_name_or_path`. PR [#1821](https://github.com/kohya-ss/sd-scripts/pull/1821) Thanks to sdbds! diff --git a/requirements.txt b/requirements.txt index 0dd1c69c..e0091749 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ pytorch-lightning==1.9.0 bitsandbytes==0.44.0 prodigyopt==1.0 lion-pytorch==0.0.6 -schedulefree==1.2.7 +schedulefree==1.4 tensorboard safetensors==0.4.4 # gradio==3.16.2 From 05bb9183fae18c62a1730fe5060f80c0b99a21f3 Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Fri, 27 Dec 2024 16:47:59 +0800 Subject: [PATCH 353/356] Add Validation loss for LoRA training --- library/config_util.py | 78 +++++++++++++++++++++++- library/train_util.py | 54 ++++++++++++++++- train_network.py | 131 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 257 insertions(+), 6 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 12d0be17..a57cd36f 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -73,6 +73,8 @@ class BaseSubsetParams: token_warmup_min: int = 1 token_warmup_step: float = 0 custom_attributes: Optional[Dict[str, Any]] = None + validation_seed: int = 0 + validation_split: float = 0.0 @dataclass @@ -102,6 +104,8 @@ class BaseDatasetParams: resolution: Optional[Tuple[int, int]] = None network_multiplier: float = 1.0 debug_dataset: bool = False + validation_seed: Optional[int] = None + validation_split: float = 0.0 @dataclass @@ -478,9 +482,27 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params)) datasets.append(dataset) + val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.params.validation_split <= 0.0: + continue + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) + val_datasets.append(dataset) + # print info info = "" for i, dataset in enumerate(datasets): @@ -566,6 +588,50 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu logger.info(f"{info}") + if len(val_datasets) > 0: + info = "" + + for i, dataset in enumerate(val_datasets): + info += dedent( + f"""\ + [Validation Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + network_multiplier: {dataset.network_multiplier} + """ + ) + + if dataset.enable_bucket: + info += indent( + dedent( + f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n""" + ), + " ", + ) + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): + info += indent( + dedent( + f"""\ + [Subset {j} of Validation Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + """ + ), + " ", + ) + + logger.info(f"{info}") + # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no @@ -574,7 +640,15 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset.make_buckets() dataset.set_seed(seed) - return DatasetGroup(datasets) + for i, dataset in enumerate(val_datasets): + logger.info(f"[Validation Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) + + return ( + DatasetGroup(datasets), + DatasetGroup(val_datasets) if val_datasets else None + ) def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): diff --git a/library/train_util.py b/library/train_util.py index 72b5b24d..a3fa98e9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -145,6 +145,17 @@ IMAGE_TRANSFORMS = transforms.Compose( TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" +def split_train_val(paths: List[str], validation_split: float, validation_seed: int) -> List[str]: + if validation_seed is not None: + print(f"Using validation seed: {validation_seed}") + prevstate = random.getstate() + random.seed(validation_seed) + random.shuffle(paths) + random.setstate(prevstate) + else: + random.shuffle(paths) + + return paths[len(paths) - round(len(paths) * validation_split):] class ImageInfo: def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: @@ -397,6 +408,8 @@ class BaseSubset: token_warmup_min: int, token_warmup_step: Union[float, int], custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: self.image_dir = image_dir self.alpha_mask = alpha_mask if alpha_mask is not None else False @@ -424,6 +437,9 @@ class BaseSubset: self.img_count = 0 + self.validation_seed = validation_seed + self.validation_split = validation_split + class DreamBoothSubset(BaseSubset): def __init__( @@ -453,6 +469,8 @@ class DreamBoothSubset(BaseSubset): token_warmup_min, token_warmup_step, custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -478,6 +496,8 @@ class DreamBoothSubset(BaseSubset): token_warmup_min, token_warmup_step, custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.is_reg = is_reg @@ -518,6 +538,8 @@ class FineTuningSubset(BaseSubset): token_warmup_min, token_warmup_step, custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -543,6 +565,8 @@ class FineTuningSubset(BaseSubset): token_warmup_min, token_warmup_step, custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.metadata_file = metadata_file @@ -579,6 +603,8 @@ class ControlNetSubset(BaseSubset): token_warmup_min, token_warmup_step, custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -604,6 +630,8 @@ class ControlNetSubset(BaseSubset): token_warmup_min, token_warmup_step, custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.conditioning_data_dir = conditioning_data_dir @@ -1799,6 +1827,9 @@ class DreamBoothDataset(BaseDataset): bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset: bool, + is_train: bool, + validation_seed: int, + validation_split: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -1808,6 +1839,9 @@ class DreamBoothDataset(BaseDataset): self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight self.latents_cache = None + self.is_train = is_train + self.validation_seed = validation_seed + self.validation_split = validation_split self.enable_bucket = enable_bucket if self.enable_bucket: @@ -1992,6 +2026,9 @@ class DreamBoothDataset(BaseDataset): ) continue + if self.is_train == False: + img_paths = split_train_val(img_paths, self.validation_split, self.validation_seed) + if subset.is_reg: num_reg_images += subset.num_repeats * len(img_paths) else: @@ -2009,7 +2046,11 @@ class DreamBoothDataset(BaseDataset): subset.img_count = len(img_paths) self.subsets.append(subset) - logger.info(f"{num_train_images} train images with repeating.") + if self.is_train: + logger.info(f"{num_train_images} train images with repeating.") + else: + logger.info(f"{num_train_images} validation images with repeating.") + self.num_train_images = num_train_images logger.info(f"{num_reg_images} reg images.") @@ -2050,6 +2091,9 @@ class FineTuningDataset(BaseDataset): bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset: bool, + is_train: bool, + validation_seed: int, + validation_split: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2276,6 +2320,9 @@ class ControlNetDataset(BaseDataset): bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset: float, + is_train: bool, + validation_seed: int, + validation_split: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2324,6 +2371,9 @@ class ControlNetDataset(BaseDataset): bucket_no_upscale, 1.0, debug_dataset, + is_train, + validation_seed, + validation_split, ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) @@ -4887,7 +4937,7 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]: import schedulefree as sf except ImportError: raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") - + if optimizer_type == "RAdamScheduleFree".lower(): optimizer_class = sf.RAdamScheduleFree logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}") diff --git a/train_network.py b/train_network.py index 5e82b307..776feaf7 100644 --- a/train_network.py +++ b/train_network.py @@ -9,6 +9,7 @@ import json from multiprocessing import Value from typing import Any, List import toml +import itertools from tqdm import tqdm @@ -114,7 +115,7 @@ class NetworkTrainer: ) if ( args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None - ): + ): logs[f"lr/d*lr/group{i}"] = ( optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] ) @@ -373,10 +374,11 @@ class NetworkTrainer: } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -398,6 +400,11 @@ class NetworkTrainer: train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + if val_dataset_group is not None: + assert ( + val_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + self.assert_extra_args(args, train_dataset_group) # may change some args # acceleratorを準備する @@ -444,6 +451,8 @@ class NetworkTrainer: vae.eval() train_dataset_group.new_cache_latents(vae, accelerator) + if val_dataset_group is not None: + val_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -459,6 +468,8 @@ class NetworkTrainer: if text_encoder_outputs_caching_strategy is not None: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype) + if val_dataset_group is not None: + self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype) # prepare network net_kwargs = {} @@ -567,6 +578,8 @@ class NetworkTrainer: # strategies are set here because they cannot be referenced in another process. Copy them with the dataset # some strategies can be None train_dataset_group.set_current_strategies() + if val_dataset_group is not None: + val_dataset_group.set_current_strategies() # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -580,6 +593,17 @@ class NetworkTrainer: persistent_workers=args.persistent_data_loader_workers, ) + val_dataloader = torch.utils.data.DataLoader( + val_dataset_group if val_dataset_group is not None else [], + batch_size=1, + shuffle=False, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + cyclic_val_dataloader = itertools.cycle(val_dataloader) + # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * math.ceil( @@ -592,6 +616,10 @@ class NetworkTrainer: # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) + # Not for sure here. + # if val_dataset_group is not None: + # val_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -1064,7 +1092,11 @@ class NetworkTrainer: ) loss_recorder = train_util.LossRecorder() + # val_loss_recorder = train_util.LossRecorder() + del train_dataset_group + if val_dataset_group is not None: + del val_dataset_group # callback for step start if hasattr(accelerator.unwrap_model(network), "on_step_start"): @@ -1308,6 +1340,77 @@ class NetworkTrainer: ) accelerator.log(logs, step=global_step) + if len(val_dataloader) > 0: + if ((args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps): + accelerator.print("\nValidating バリデーション処理...") + + total_loss = 0.0 + + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションテップ"): + batch = next(cyclic_val_dataloader) + + timesteps_list = [10, 350, 500, 650, 990] + + val_loss = 0.0 + + for fixed_timesteps in timesteps_list: + with torch.set_grad_enabled(False), accelerator.autocast(): + noise = torch.randn_like(latents, device=latents.device) + b_size = latents.shape[0] + + timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu") + timesteps = timesteps.long().to(latents.device) + + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + with accelerator.autocast(): + noise_pred = self.call_unet( + args, + accelerator, + unet, + noisy_latents.requires_grad_(False), + timesteps, + text_encoder_conds, + batch, + weight_dtype, + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. + loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + val_loss += loss / len(timesteps_list) + + total_loss += val_loss.detach().item() + + current_val_loss = total_loss / validation_steps + # val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_val_loss) + + if len(accelerator.trackers) > 0: + logs = {"loss/current_val_loss": current_val_loss} + accelerator.log(logs, step=global_step) + + # avr_loss: float = val_loss_recorder.moving_average + # logs = {"loss/average_val_loss": avr_loss} + # accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break @@ -1496,6 +1599,30 @@ def setup_parser() -> argparse.ArgumentParser: help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch." + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする", ) + parser.add_argument( + "--validation_seed", + type=int, + default=None, + help="Validation seed / 検証シード" + ) + parser.add_argument( + "--validation_split", + type=float, + default=0.0, + help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" + ) + parser.add_argument( + "--validation_every_n_step", + type=int, + default=None, + help="Number of train steps for counting validation loss. By default, validation per train epoch is performed / 学習エポックごとに検証を行う場合はNoneを指定する" + ) + parser.add_argument( + "--max_validation_steps", + type=int, + default=None, + help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset / 検証データセット全体を検証する場合はNoneを指定する" + ) # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") From 62164e57925125ed6268983ffa441f1ffecc0e6d Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Fri, 27 Dec 2024 17:28:05 +0800 Subject: [PATCH 354/356] Change val loss calculate method --- train_network.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/train_network.py b/train_network.py index 776feaf7..5fd1b212 100644 --- a/train_network.py +++ b/train_network.py @@ -1383,16 +1383,20 @@ class NetworkTrainer: else: target = noise - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) - loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) - if weighting is not None: - loss = loss * weighting - if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) + # huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + # loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + # if weighting is not None: + # loss = loss * weighting + # if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + # loss = apply_masked_loss(loss, batch) + # loss = loss.mean([1, 2, 3]) # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. - loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + # loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし From 64bd5317dc9cb39d69ab7728f36b03157c9b341f Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Sat, 28 Dec 2024 11:42:15 +0800 Subject: [PATCH 355/356] Split val latents/batch and pick up val latents shape size which equal to training batch. --- train_network.py | 45 +++++++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/train_network.py b/train_network.py index 5fd1b212..6bce9e96 100644 --- a/train_network.py +++ b/train_network.py @@ -1349,7 +1349,27 @@ class NetworkTrainer: with torch.no_grad(): validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションテップ"): - batch = next(cyclic_val_dataloader) + + while True: + val_batch = next(cyclic_val_dataloader) + + if "latents" in val_batch and val_batch["latents"] is not None: + val_latents = val_batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): + # latentに変換 + val_latents = self.encode_images_to_latents(args, accelerator, vae, val_batch["images"].to(vae_dtype)) + val_latents = val_latents.to(dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(val_latents)): + accelerator.print("NaN found in validation latents, replacing with zeros") + val_latents = torch.nan_to_num(val_latents, 0, out=val_latents) + + val_latents = self.shift_scale_latents(args, val_latents) + + if val_latents.shape == latents.shape: + break timesteps_list = [10, 350, 500, 650, 990] @@ -1357,13 +1377,13 @@ class NetworkTrainer: for fixed_timesteps in timesteps_list: with torch.set_grad_enabled(False), accelerator.autocast(): - noise = torch.randn_like(latents, device=latents.device) - b_size = latents.shape[0] + noise = torch.randn_like(val_latents, device=val_latents.device) + b_size = val_latents.shape[0] timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu") - timesteps = timesteps.long().to(latents.device) + timesteps = timesteps.long().to(val_latents.device) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(val_latents, noise, timesteps) with accelerator.autocast(): noise_pred = self.call_unet( @@ -1373,27 +1393,16 @@ class NetworkTrainer: noisy_latents.requires_grad_(False), timesteps, text_encoder_conds, - batch, + val_batch, weight_dtype, ) if args.v_parameterization: # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) + target = noise_scheduler.get_velocity(val_latents, noise, timesteps) else: target = noise - # huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) - # loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) - # if weighting is not None: - # loss = loss * weighting - # if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - # loss = apply_masked_loss(loss, batch) - # loss = loss.mean([1, 2, 3]) - - # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. - # loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) From cb89e0284e1a25b41401861107159e6b943ee387 Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Sat, 28 Dec 2024 11:57:04 +0800 Subject: [PATCH 356/356] Change val latent loss compare --- train_network.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/train_network.py b/train_network.py index 6bce9e96..7276d5dc 100644 --- a/train_network.py +++ b/train_network.py @@ -1350,6 +1350,8 @@ class NetworkTrainer: validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションテップ"): + val_latents = None + while True: val_batch = next(cyclic_val_dataloader) @@ -1371,19 +1373,22 @@ class NetworkTrainer: if val_latents.shape == latents.shape: break + if val_latents is not None: + del val_latents + timesteps_list = [10, 350, 500, 650, 990] val_loss = 0.0 for fixed_timesteps in timesteps_list: with torch.set_grad_enabled(False), accelerator.autocast(): - noise = torch.randn_like(val_latents, device=val_latents.device) - b_size = val_latents.shape[0] + noise = torch.randn_like(latents, device=latents.device) + b_size = latents.shape[0] timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu") - timesteps = timesteps.long().to(val_latents.device) + timesteps = timesteps.long().to(latents.device) - noisy_latents = noise_scheduler.add_noise(val_latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) with accelerator.autocast(): noise_pred = self.call_unet( @@ -1399,7 +1404,7 @@ class NetworkTrainer: if args.v_parameterization: # v-parameterization training - target = noise_scheduler.get_velocity(val_latents, noise, timesteps) + target = noise_scheduler.get_velocity(latents, noise, timesteps) else: target = noise