From 1e11b7bd096eb3945f7e0b1ffe9a62b1c1028b33 Mon Sep 17 00:00:00 2001 From: feffy380 <114889020+feffy380@users.noreply.github.com> Date: Sun, 3 Mar 2024 07:30:58 +0100 Subject: [PATCH 1/8] implement token downsampling optimization --- library/sdxl_train_util.py | 5 ++ library/token_merging.py | 106 +++++++++++++++++++++++++++++++++++++ library/train_util.py | 16 ++++++ 3 files changed, 127 insertions(+) create mode 100644 library/token_merging.py diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 1932bf88..e5a69db5 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -12,6 +12,7 @@ from tqdm import tqdm from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline +from library import token_merging from .utils import setup_logging setup_logging() import logging @@ -57,6 +58,10 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + # apply token merging patch + if args.todo_factor: + token_merging.patch_attention(unet, args) + return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info diff --git a/library/token_merging.py b/library/token_merging.py new file mode 100644 index 00000000..1f53cda8 --- /dev/null +++ b/library/token_merging.py @@ -0,0 +1,106 @@ +# based on: +# https://github.com/ethansmith2000/ImprovedTokenMerge +# https://github.com/ethansmith2000/comfy-todo (MIT) + +import ast +import math + +import torch +import torch.nn.functional as F + +from library.sdxl_original_unet import SdxlUNet2DConditionModel + +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) + + +def up_or_downsample(item, cur_w, cur_h, new_w, new_h, method): + batch_size = item.shape[0] + + item = item.reshape(batch_size, cur_h, cur_w, -1).permute(0, 3, 1, 2) + item = F.interpolate(item, size=(new_h, new_w), mode=method).permute(0, 2, 3, 1) + item = item.reshape(batch_size, new_h * new_w, -1) + + return item + + +def compute_merge(x: torch.Tensor, tome_info: dict): + original_h, original_w = tome_info["size"] + original_tokens = original_h * original_w + downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1]))) + cur_h = original_h // downsample + cur_w = original_w // downsample + downsample_factor_1 = tome_info["args"]["downsample_factor_depth_1"] + downsample_factor_2 = tome_info["args"]["downsample_factor_depth_2"] + + merge_op = lambda x: x + if downsample == 1 and downsample_factor_1 > 1: + new_h = int(cur_h / downsample_factor_1) + new_w = int(cur_w / downsample_factor_1) + merge_op = lambda x: up_or_downsample(x, cur_w, cur_h, new_w, new_h, tome_info["args"]["downsample_method"]) + elif downsample == 2 and downsample_factor_2 > 1: + new_h = int(cur_h / downsample_factor_2) + new_w = int(cur_w / downsample_factor_2) + merge_op = lambda x: up_or_downsample(x, cur_w, cur_h, new_w, new_h, tome_info["args"]["downsample_method"]) + + return merge_op + + +def hook_tome_model(model: torch.nn.Module): + """ Adds a forward pre hook to get the image size. This hook can be removed with remove_patch. """ + def hook(module, args): + module._tome_info["size"] = (args[0].shape[2], args[0].shape[3]) + return None + + model._tome_info["hooks"].append(model.register_forward_pre_hook(hook)) + + +def hook_attention(attn: torch.nn.Module): + """ Adds a forward pre hook to downsample attention keys and values. This hook can be removed with remove_patch. """ + def hook(module, args, kwargs): + hidden_states = args[0] + m = compute_merge(hidden_states, module._tome_info) + kwargs["context"] = m(hidden_states) + return args, kwargs + + attn._tome_info["hooks"].append(attn.register_forward_pre_hook(hook, with_kwargs=True)) + + +def patch_attention(unet: torch.nn.Module, args): + """ Patches the UNet's transformer blocks to apply token downsampling. """ + is_sdxl = isinstance(unet, SdxlUNet2DConditionModel) + todo_kwargs = { + "downsample_factor_depth_1": args.todo_factor, + "downsample_factor_depth_2": args.todo_factor if is_sdxl else 1, # SDXL doesn't have depth 1, so downsample here + "downsample_method": "nearest-exact", + } + if args.todo_args: + for arg in args.todo_args: + key, value = arg.split("=") + todo_kwargs[key] = ast.literal_eval(value) + logger.info(f"enable token downsampling optimization | {todo_kwargs}") + + unet._tome_info = { + "size": None, + "hooks": [], + "args": todo_kwargs, + } + hook_tome_model(unet) + + for _, module in unet.named_modules(): + if module.__class__.__name__ == "BasicTransformerBlock": + module.attn1._tome_info = unet._tome_info + hook_attention(module.attn1) + + return unet + + +def remove_patch(unet: torch.nn.Module): + if hasattr(unet, "_tome_info"): + for hook in unet._tome_info["hooks"]: + hook.remove() + unet._tome_info["hooks"].clear() + + return unet diff --git a/library/train_util.py b/library/train_util.py index b71e4edc..18cb1a21 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -69,6 +69,7 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec +from library import token_merging from library.utils import setup_logging setup_logging() @@ -3135,6 +3136,17 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))", ) + parser.add_argument( + "--todo_factor", + type=float, + help="token downsampling (ToDo) factor > 1 (recommend starting with 2)", + ) + parser.add_argument( + "--todo_args", + type=str, + nargs="*", + help='additional arguments for ToDo (like "downsample_factor_depth_2=2")', + ) parser.add_argument( "--lowram", @@ -4186,6 +4198,10 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + # apply token merging patch + if args.todo_factor: + token_merging.patch_attention(unet, args) + return text_encoder, vae, unet, load_stable_diffusion_format From f68c8bc23758e51b946e8a1a38871aa6c50a05d5 Mon Sep 17 00:00:00 2001 From: feffy380 <114889020+feffy380@users.noreply.github.com> Date: Sun, 3 Mar 2024 23:49:10 +0100 Subject: [PATCH 2/8] --todo_factor: accept second factor for depth_2 --- library/token_merging.py | 29 ++++++++++++++++++++++++----- library/train_util.py | 3 ++- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/library/token_merging.py b/library/token_merging.py index 1f53cda8..66ef539c 100644 --- a/library/token_merging.py +++ b/library/token_merging.py @@ -68,20 +68,39 @@ def hook_attention(attn: torch.nn.Module): attn._tome_info["hooks"].append(attn.register_forward_pre_hook(hook, with_kwargs=True)) -def patch_attention(unet: torch.nn.Module, args): - """ Patches the UNet's transformer blocks to apply token downsampling. """ - is_sdxl = isinstance(unet, SdxlUNet2DConditionModel) +def parse_todo_args(args, is_sdxl: bool) -> dict: + if len(args.todo_factor) > 2: + raise ValueError(f"--todo_factor expects 1 or 2 arguments, received {len(args.todo_factor)}") + elif is_sdxl and len(args.todo_factor) > 1: + raise ValueError(f"--todo_factor expects expects exactly 1 argument for SDXL, received {len(args.todo_factor)}") + todo_kwargs = { - "downsample_factor_depth_1": args.todo_factor, - "downsample_factor_depth_2": args.todo_factor if is_sdxl else 1, # SDXL doesn't have depth 1, so downsample here + "downsample_factor_depth_1": 1, + "downsample_factor_depth_2": 1, "downsample_method": "nearest-exact", } + + if is_sdxl: + # SDXL doesn't have depth 1, so default to depth 2 + todo_kwargs["downsample_factor_depth_2"] = args.todo_factor[0] + else: + todo_kwargs["downsample_factor_depth_1"] = args.todo_factor[0] + todo_kwargs["downsample_factor_depth_2"] = args.todo_factor[1] if len(args.todo_factor) == 2 else 1 + if args.todo_args: for arg in args.todo_args: key, value = arg.split("=") todo_kwargs[key] = ast.literal_eval(value) logger.info(f"enable token downsampling optimization | {todo_kwargs}") + return todo_kwargs + + +def patch_attention(unet: torch.nn.Module, args): + """ Patches the UNet's transformer blocks to apply token downsampling. """ + is_sdxl = isinstance(unet, SdxlUNet2DConditionModel) + todo_kwargs = parse_todo_args(args, is_sdxl) + unet._tome_info = { "size": None, "hooks": [], diff --git a/library/train_util.py b/library/train_util.py index 18cb1a21..da83791e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3139,7 +3139,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--todo_factor", type=float, - help="token downsampling (ToDo) factor > 1 (recommend starting with 2)", + nargs="+", + help="token downsampling (ToDo) factor > 1 (recommend around 2-4). SD1/2 accepts up to 2 values (for depth_1 and depth_2)", ) parser.add_argument( "--todo_args", From 8413112734973a782888c03938d93b44a66f7dd2 Mon Sep 17 00:00:00 2001 From: feffy380 <114889020+feffy380@users.noreply.github.com> Date: Mon, 4 Mar 2024 14:07:35 +0100 Subject: [PATCH 3/8] fix --todo_args parsing --- library/token_merging.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/library/token_merging.py b/library/token_merging.py index 66ef539c..028573f4 100644 --- a/library/token_merging.py +++ b/library/token_merging.py @@ -2,7 +2,6 @@ # https://github.com/ethansmith2000/ImprovedTokenMerge # https://github.com/ethansmith2000/comfy-todo (MIT) -import ast import math import torch @@ -90,7 +89,10 @@ def parse_todo_args(args, is_sdxl: bool) -> dict: if args.todo_args: for arg in args.todo_args: key, value = arg.split("=") - todo_kwargs[key] = ast.literal_eval(value) + todo_kwargs[key] = value + todo_kwargs["downsample_factor_depth_1"] = float(todo_kwargs["downsample_factor_depth_1"]) + todo_kwargs["downsample_factor_depth_2"] = float(todo_kwargs["downsample_factor_depth_2"]) + logger.info(f"enable token downsampling optimization | {todo_kwargs}") return todo_kwargs From c0e4c157754d24acd7dd303cc0975762fc37b63a Mon Sep 17 00:00:00 2001 From: feffy380 <114889020+feffy380@users.noreply.github.com> Date: Mon, 4 Mar 2024 14:23:10 +0100 Subject: [PATCH 4/8] simplify ToDo args --- library/sdxl_train_util.py | 3 +- library/token_merging.py | 60 ++++++++++---------------------------- library/train_util.py | 15 ++++++---- 3 files changed, 27 insertions(+), 51 deletions(-) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index e5a69db5..1b4e3b74 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -60,7 +60,8 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): # apply token merging patch if args.todo_factor: - token_merging.patch_attention(unet, args) + token_merging.patch_attention(unet, args, is_sdxl=True) + logger.info(f"enable token downsampling optimization | {unet._tome_info['args']}") return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info diff --git a/library/token_merging.py b/library/token_merging.py index 028573f4..b53a3c4f 100644 --- a/library/token_merging.py +++ b/library/token_merging.py @@ -7,15 +7,8 @@ import math import torch import torch.nn.functional as F -from library.sdxl_original_unet import SdxlUNet2DConditionModel -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - - -def up_or_downsample(item, cur_w, cur_h, new_w, new_h, method): +def up_or_downsample(item, cur_w, cur_h, new_w, new_h, method="nearest-exact"): batch_size = item.shape[0] item = item.reshape(batch_size, cur_h, cur_w, -1).permute(0, 3, 1, 2) @@ -31,18 +24,15 @@ def compute_merge(x: torch.Tensor, tome_info: dict): downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1]))) cur_h = original_h // downsample cur_w = original_w // downsample - downsample_factor_1 = tome_info["args"]["downsample_factor_depth_1"] - downsample_factor_2 = tome_info["args"]["downsample_factor_depth_2"] + + args = tome_info["args"] + downsample_factor = args["downsample_factor"] merge_op = lambda x: x - if downsample == 1 and downsample_factor_1 > 1: - new_h = int(cur_h / downsample_factor_1) - new_w = int(cur_w / downsample_factor_1) - merge_op = lambda x: up_or_downsample(x, cur_w, cur_h, new_w, new_h, tome_info["args"]["downsample_method"]) - elif downsample == 2 and downsample_factor_2 > 1: - new_h = int(cur_h / downsample_factor_2) - new_w = int(cur_w / downsample_factor_2) - merge_op = lambda x: up_or_downsample(x, cur_w, cur_h, new_w, new_h, tome_info["args"]["downsample_method"]) + if downsample <= args["max_downsample"]: + new_h = int(cur_h / downsample_factor) + new_w = int(cur_w / downsample_factor) + merge_op = lambda x: up_or_downsample(x, cur_w, cur_h, new_w, new_h) return merge_op @@ -67,40 +57,22 @@ def hook_attention(attn: torch.nn.Module): attn._tome_info["hooks"].append(attn.register_forward_pre_hook(hook, with_kwargs=True)) -def parse_todo_args(args, is_sdxl: bool) -> dict: - if len(args.todo_factor) > 2: - raise ValueError(f"--todo_factor expects 1 or 2 arguments, received {len(args.todo_factor)}") - elif is_sdxl and len(args.todo_factor) > 1: - raise ValueError(f"--todo_factor expects expects exactly 1 argument for SDXL, received {len(args.todo_factor)}") +def parse_todo_args(args, is_sdxl: bool = False) -> dict: + if args.todo_max_downsample is None: + args.todo_max_downsample = 2 if is_sdxl else 1 + if is_sdxl and args.todo_max_downsample not in (2, 4): + raise ValueError(f"--todo_max_downsample for SDXL must be 2 or 4, received {args.todo_factor}") todo_kwargs = { - "downsample_factor_depth_1": 1, - "downsample_factor_depth_2": 1, - "downsample_method": "nearest-exact", + "downsample_factor": args.todo_factor, + "max_downsample": args.todo_max_downsample, } - if is_sdxl: - # SDXL doesn't have depth 1, so default to depth 2 - todo_kwargs["downsample_factor_depth_2"] = args.todo_factor[0] - else: - todo_kwargs["downsample_factor_depth_1"] = args.todo_factor[0] - todo_kwargs["downsample_factor_depth_2"] = args.todo_factor[1] if len(args.todo_factor) == 2 else 1 - - if args.todo_args: - for arg in args.todo_args: - key, value = arg.split("=") - todo_kwargs[key] = value - todo_kwargs["downsample_factor_depth_1"] = float(todo_kwargs["downsample_factor_depth_1"]) - todo_kwargs["downsample_factor_depth_2"] = float(todo_kwargs["downsample_factor_depth_2"]) - - logger.info(f"enable token downsampling optimization | {todo_kwargs}") - return todo_kwargs -def patch_attention(unet: torch.nn.Module, args): +def patch_attention(unet: torch.nn.Module, args, is_sdxl=False): """ Patches the UNet's transformer blocks to apply token downsampling. """ - is_sdxl = isinstance(unet, SdxlUNet2DConditionModel) todo_kwargs = parse_todo_args(args, is_sdxl) unet._tome_info = { diff --git a/library/train_util.py b/library/train_util.py index da83791e..505cc33a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3139,14 +3139,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--todo_factor", type=float, - nargs="+", - help="token downsampling (ToDo) factor > 1 (recommend around 2-4). SD1/2 accepts up to 2 values (for depth_1 and depth_2)", + help="token downsampling (ToDo) factor > 1 (recommend around 2-4)", ) parser.add_argument( - "--todo_args", - type=str, - nargs="*", - help='additional arguments for ToDo (like "downsample_factor_depth_2=2")', + "--todo_max_downsample", + type=int, + choices=[1, 2, 4, 8], + help=( + "apply ToDo to layers with at most this amount of downsampling." + " SDXL only accepts 2 and 4. Recommend 1 or 2. Default 1 (or 2 for SDXL)" + ), ) parser.add_argument( @@ -4202,6 +4204,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio # apply token merging patch if args.todo_factor: token_merging.patch_attention(unet, args) + logger.info(f"enable token downsampling optimization | {unet._tome_info['args']}") return text_encoder, vae, unet, load_stable_diffusion_format From 4fa42baafbe7255f0f1368ae47d9c618c0ed021f Mon Sep 17 00:00:00 2001 From: feffy380 <114889020+feffy380@users.noreply.github.com> Date: Tue, 5 Mar 2024 12:56:26 +0100 Subject: [PATCH 5/8] add ToDo settings to lora metadata --- train_network.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/train_network.py b/train_network.py index e5b26d8a..053456bd 100644 --- a/train_network.py +++ b/train_network.py @@ -683,6 +683,10 @@ class NetworkTrainer: vae_name = os.path.basename(vae_name) metadata["ss_vae_name"] = vae_name + if args.todo_factor: + metadata["ss_todo_factor"] = args.todo_factor + metadata["ss_todo_max_downsample"] = args.todo_max_downsample + metadata = {k: str(v) for k, v in metadata.items()} # make minimum metadata for filtering From 6a22a5b65ddee371d14a56f5bc950800a464e70f Mon Sep 17 00:00:00 2001 From: feffy380 <114889020+feffy380@users.noreply.github.com> Date: Mon, 11 Mar 2024 00:11:53 +0100 Subject: [PATCH 6/8] clean up naming, replace max_downsample with max_depth to avoid confusion --- library/sdxl_train_util.py | 6 +-- ...token_merging.py => token_downsampling.py} | 42 +++++++++---------- library/train_util.py | 13 +++--- train_network.py | 2 +- 4 files changed, 31 insertions(+), 32 deletions(-) rename library/{token_merging.py => token_downsampling.py} (66%) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 1b4e3b74..365f8fa2 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -12,7 +12,7 @@ from tqdm import tqdm from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline -from library import token_merging +from library import token_downsampling from .utils import setup_logging setup_logging() import logging @@ -60,8 +60,8 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): # apply token merging patch if args.todo_factor: - token_merging.patch_attention(unet, args, is_sdxl=True) - logger.info(f"enable token downsampling optimization | {unet._tome_info['args']}") + token_downsampling.apply_patch(unet, args, is_sdxl=True) + logger.info(f"enable token downsampling optimization | {unet._todo_info['args']}") return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info diff --git a/library/token_merging.py b/library/token_downsampling.py similarity index 66% rename from library/token_merging.py rename to library/token_downsampling.py index b53a3c4f..9b64d3d9 100644 --- a/library/token_merging.py +++ b/library/token_downsampling.py @@ -18,18 +18,18 @@ def up_or_downsample(item, cur_w, cur_h, new_w, new_h, method="nearest-exact"): return item -def compute_merge(x: torch.Tensor, tome_info: dict): - original_h, original_w = tome_info["size"] +def compute_merge(x: torch.Tensor, todo_info: dict): + original_h, original_w = todo_info["size"] original_tokens = original_h * original_w downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1]))) cur_h = original_h // downsample cur_w = original_w // downsample - args = tome_info["args"] + args = todo_info["args"] downsample_factor = args["downsample_factor"] merge_op = lambda x: x - if downsample <= args["max_downsample"]: + if downsample <= args["max_depth"]: new_h = int(cur_h / downsample_factor) new_w = int(cur_w / downsample_factor) merge_op = lambda x: up_or_downsample(x, cur_w, cur_h, new_w, new_h) @@ -37,63 +37,63 @@ def compute_merge(x: torch.Tensor, tome_info: dict): return merge_op -def hook_tome_model(model: torch.nn.Module): +def hook_unet(model: torch.nn.Module): """ Adds a forward pre hook to get the image size. This hook can be removed with remove_patch. """ def hook(module, args): - module._tome_info["size"] = (args[0].shape[2], args[0].shape[3]) + module._todo_info["size"] = (args[0].shape[2], args[0].shape[3]) return None - model._tome_info["hooks"].append(model.register_forward_pre_hook(hook)) + model._todo_info["hooks"].append(model.register_forward_pre_hook(hook)) def hook_attention(attn: torch.nn.Module): """ Adds a forward pre hook to downsample attention keys and values. This hook can be removed with remove_patch. """ def hook(module, args, kwargs): hidden_states = args[0] - m = compute_merge(hidden_states, module._tome_info) + m = compute_merge(hidden_states, module._todo_info) kwargs["context"] = m(hidden_states) return args, kwargs - attn._tome_info["hooks"].append(attn.register_forward_pre_hook(hook, with_kwargs=True)) + attn._todo_info["hooks"].append(attn.register_forward_pre_hook(hook, with_kwargs=True)) def parse_todo_args(args, is_sdxl: bool = False) -> dict: - if args.todo_max_downsample is None: - args.todo_max_downsample = 2 if is_sdxl else 1 - if is_sdxl and args.todo_max_downsample not in (2, 4): - raise ValueError(f"--todo_max_downsample for SDXL must be 2 or 4, received {args.todo_factor}") + if args.todo_max_depth is None: + args.todo_max_depth = 2 if is_sdxl else 1 + if is_sdxl and args.todo_max_depth not in (2, 3): + raise ValueError(f"--todo_max_depth for SDXL must be 2 or 3, received {args.todo_factor}") todo_kwargs = { "downsample_factor": args.todo_factor, - "max_downsample": args.todo_max_downsample, + "max_depth": 2**(args.todo_max_depth - 1), } return todo_kwargs -def patch_attention(unet: torch.nn.Module, args, is_sdxl=False): +def apply_patch(unet: torch.nn.Module, args, is_sdxl=False): """ Patches the UNet's transformer blocks to apply token downsampling. """ todo_kwargs = parse_todo_args(args, is_sdxl) - unet._tome_info = { + unet._todo_info = { "size": None, "hooks": [], "args": todo_kwargs, } - hook_tome_model(unet) + hook_unet(unet) for _, module in unet.named_modules(): if module.__class__.__name__ == "BasicTransformerBlock": - module.attn1._tome_info = unet._tome_info + module.attn1._todo_info = unet._todo_info hook_attention(module.attn1) return unet def remove_patch(unet: torch.nn.Module): - if hasattr(unet, "_tome_info"): - for hook in unet._tome_info["hooks"]: + if hasattr(unet, "_todo_info"): + for hook in unet._todo_info["hooks"]: hook.remove() - unet._tome_info["hooks"].clear() + unet._todo_info["hooks"].clear() return unet diff --git a/library/train_util.py b/library/train_util.py index 505cc33a..12453318 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -69,7 +69,7 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec -from library import token_merging +from library import token_downsampling from library.utils import setup_logging setup_logging() @@ -3142,12 +3142,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="token downsampling (ToDo) factor > 1 (recommend around 2-4)", ) parser.add_argument( - "--todo_max_downsample", + "--todo_max_depth", type=int, - choices=[1, 2, 4, 8], + choices=[1, 2, 3, 4], help=( - "apply ToDo to layers with at most this amount of downsampling." - " SDXL only accepts 2 and 4. Recommend 1 or 2. Default 1 (or 2 for SDXL)" + "apply ToDo to deeper layers (lower quailty for slight speed increase). SDXL only accepts 2 and 3. Recommend 1 or 2. Default 1 (or 2 for SDXL)" ), ) @@ -4203,8 +4202,8 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio # apply token merging patch if args.todo_factor: - token_merging.patch_attention(unet, args) - logger.info(f"enable token downsampling optimization | {unet._tome_info['args']}") + token_downsampling.apply_patch(unet, args) + logger.info(f"enable token downsampling optimization | {unet._todo_info['args']}") return text_encoder, vae, unet, load_stable_diffusion_format diff --git a/train_network.py b/train_network.py index 053456bd..276ab3e7 100644 --- a/train_network.py +++ b/train_network.py @@ -685,7 +685,7 @@ class NetworkTrainer: if args.todo_factor: metadata["ss_todo_factor"] = args.todo_factor - metadata["ss_todo_max_downsample"] = args.todo_max_downsample + metadata["ss_todo_max_depth"] = args.todo_max_depth metadata = {k: str(v) for k, v in metadata.items()} From 1d351448c0d8b75fa253262c251cb2a4f5677d87 Mon Sep 17 00:00:00 2001 From: feffy380 <114889020+feffy380@users.noreply.github.com> Date: Mon, 11 Mar 2024 00:58:28 +0100 Subject: [PATCH 7/8] fix typo --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 12453318..40d6e671 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3146,7 +3146,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: type=int, choices=[1, 2, 3, 4], help=( - "apply ToDo to deeper layers (lower quailty for slight speed increase). SDXL only accepts 2 and 3. Recommend 1 or 2. Default 1 (or 2 for SDXL)" + "apply ToDo to deeper layers (lower quality for slight speed increase). SDXL only accepts 2 and 3. Recommend 1 or 2. Default 1 (or 2 for SDXL)" ), ) From 2629117eacc81975b2a55b471d16212e54a1cd79 Mon Sep 17 00:00:00 2001 From: feffy380 <114889020+feffy380@users.noreply.github.com> Date: Mon, 11 Mar 2024 18:00:07 +0100 Subject: [PATCH 8/8] todo_factor: accept multiple values (per depth). autodetect todo_max_factor --- library/sdxl_train_util.py | 2 +- library/token_downsampling.py | 28 ++++++++++++++++++++++------ library/train_util.py | 5 +++-- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 365f8fa2..7ce19966 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -61,7 +61,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): # apply token merging patch if args.todo_factor: token_downsampling.apply_patch(unet, args, is_sdxl=True) - logger.info(f"enable token downsampling optimization | {unet._todo_info['args']}") + logger.info(f"enable token downsampling optimization: downsample_factor={args.todo_factor}, max_depth={args.todo_max_depth}") return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info diff --git a/library/token_downsampling.py b/library/token_downsampling.py index 9b64d3d9..0e1b0138 100644 --- a/library/token_downsampling.py +++ b/library/token_downsampling.py @@ -26,10 +26,10 @@ def compute_merge(x: torch.Tensor, todo_info: dict): cur_w = original_w // downsample args = todo_info["args"] - downsample_factor = args["downsample_factor"] merge_op = lambda x: x if downsample <= args["max_depth"]: + downsample_factor = args["downsample_factor"][downsample] new_h = int(cur_h / downsample_factor) new_w = int(cur_w / downsample_factor) merge_op = lambda x: up_or_downsample(x, cur_w, cur_h, new_w, new_h) @@ -58,14 +58,30 @@ def hook_attention(attn: torch.nn.Module): def parse_todo_args(args, is_sdxl: bool = False) -> dict: + # validate max_depth if args.todo_max_depth is None: - args.todo_max_depth = 2 if is_sdxl else 1 - if is_sdxl and args.todo_max_depth not in (2, 3): - raise ValueError(f"--todo_max_depth for SDXL must be 2 or 3, received {args.todo_factor}") + args.todo_max_depth = min(len(args.todo_factor), 4) + if is_sdxl and args.todo_max_depth > 2: + raise ValueError(f"todo_max_depth for SDXL cannot be larger than 2, received {args.todo_max_depth}") + + # validate factor + if len(args.todo_factor) > 1: + if len(args.todo_factor) != args.todo_max_depth: + raise ValueError(f"todo_factor number of values must be 1 or same as todo_max_depth, received {len(args.todo_factor)}") + + # create dict of factors to support per-depth override + factors = args.todo_factor + if len(factors) == 1: + factors *= args.todo_max_depth + factors = {2**(i + int(is_sdxl)): factor for i, factor in enumerate(factors)} + + # convert depth to powers of 2 to match layer dimensions: [1,2,3,4] -> [1,2,4,8] + # offset by 1 for sdxl which starts at 2 + max_depth = 2**(args.todo_max_depth + int(is_sdxl) - 1) todo_kwargs = { - "downsample_factor": args.todo_factor, - "max_depth": 2**(args.todo_max_depth - 1), + "downsample_factor": factors, + "max_depth": max_depth, } return todo_kwargs diff --git a/library/train_util.py b/library/train_util.py index 40d6e671..0139896a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3139,7 +3139,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--todo_factor", type=float, - help="token downsampling (ToDo) factor > 1 (recommend around 2-4)", + nargs="+", + help="token downsampling (ToDo) factor > 1 (recommend around 2-4). Specify multiple to set factor for each depth", ) parser.add_argument( "--todo_max_depth", @@ -4203,7 +4204,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio # apply token merging patch if args.todo_factor: token_downsampling.apply_patch(unet, args) - logger.info(f"enable token downsampling optimization | {unet._todo_info['args']}") + logger.info(f"enable token downsampling optimization: downsample_factor={args.todo_factor}, max_depth={args.todo_max_depth}") return text_encoder, vae, unet, load_stable_diffusion_format