diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py
index f78d9424..3ad6e7cd 100644
--- a/library/sdxl_train_util.py
+++ b/library/sdxl_train_util.py
@@ -13,6 +13,7 @@ from tqdm import tqdm
 from transformers import CLIPTokenizer
 from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
 from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
+from library import token_downsampling
 from .utils import setup_logging
 
 setup_logging()
@@ -60,6 +61,11 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
             clean_memory_on_device(accelerator.device)
         accelerator.wait_for_everyone()
 
+    # apply token merging patch
+    if args.todo_factor:
+        token_downsampling.apply_patch(unet, args, is_sdxl=True)
+        logger.info(f"enable token downsampling optimization: downsample_factor={args.todo_factor}, max_depth={args.todo_max_depth}")
+
     return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
diff --git a/library/token_downsampling.py b/library/token_downsampling.py
new file mode 100644
index 00000000..0e1b0138
--- /dev/null
+++ b/library/token_downsampling.py
@@ -0,0 +1,115 @@
+# based on:
+# https://github.com/ethansmith2000/ImprovedTokenMerge
+# https://github.com/ethansmith2000/comfy-todo (MIT)
+
+import math
+
+import torch
+import torch.nn.functional as F
+
+
+def up_or_downsample(item, cur_w, cur_h, new_w, new_h, method="nearest-exact"):
+    batch_size = item.shape[0]
+
+    item = item.reshape(batch_size, cur_h, cur_w, -1).permute(0, 3, 1, 2)
+    item = F.interpolate(item, size=(new_h, new_w), mode=method).permute(0, 2, 3, 1)
+    item = item.reshape(batch_size, new_h * new_w, -1)
+
+    return item
+
+
+def compute_merge(x: torch.Tensor, todo_info: dict):
+    original_h, original_w = todo_info["size"]
+    original_tokens = original_h * original_w
+    downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
+    cur_h = original_h // downsample
+    cur_w = original_w // downsample
+
+    args = todo_info["args"]
+
+    merge_op = lambda x: x
+    if downsample <= args["max_depth"]:
+        downsample_factor = args["downsample_factor"][downsample]
+        new_h = int(cur_h / downsample_factor)
+        new_w = int(cur_w / downsample_factor)
+        merge_op = lambda x: up_or_downsample(x, cur_w, cur_h, new_w, new_h)
+
+    return merge_op
+
+
+def hook_unet(model: torch.nn.Module):
+    """ Adds a forward pre hook to get the image size. This hook can be removed with remove_patch. """
+    def hook(module, args):
+        module._todo_info["size"] = (args[0].shape[2], args[0].shape[3])
+        return None
+
+    model._todo_info["hooks"].append(model.register_forward_pre_hook(hook))
+
+
+def hook_attention(attn: torch.nn.Module):
+    """ Adds a forward pre hook to downsample attention keys and values. This hook can be removed with remove_patch. """
""" + def hook(module, args, kwargs): + hidden_states = args[0] + m = compute_merge(hidden_states, module._todo_info) + kwargs["context"] = m(hidden_states) + return args, kwargs + + attn._todo_info["hooks"].append(attn.register_forward_pre_hook(hook, with_kwargs=True)) + + +def parse_todo_args(args, is_sdxl: bool = False) -> dict: + # validate max_depth + if args.todo_max_depth is None: + args.todo_max_depth = min(len(args.todo_factor), 4) + if is_sdxl and args.todo_max_depth > 2: + raise ValueError(f"todo_max_depth for SDXL cannot be larger than 2, received {args.todo_max_depth}") + + # validate factor + if len(args.todo_factor) > 1: + if len(args.todo_factor) != args.todo_max_depth: + raise ValueError(f"todo_factor number of values must be 1 or same as todo_max_depth, received {len(args.todo_factor)}") + + # create dict of factors to support per-depth override + factors = args.todo_factor + if len(factors) == 1: + factors *= args.todo_max_depth + factors = {2**(i + int(is_sdxl)): factor for i, factor in enumerate(factors)} + + # convert depth to powers of 2 to match layer dimensions: [1,2,3,4] -> [1,2,4,8] + # offset by 1 for sdxl which starts at 2 + max_depth = 2**(args.todo_max_depth + int(is_sdxl) - 1) + + todo_kwargs = { + "downsample_factor": factors, + "max_depth": max_depth, + } + + return todo_kwargs + + +def apply_patch(unet: torch.nn.Module, args, is_sdxl=False): + """ Patches the UNet's transformer blocks to apply token downsampling. """ + todo_kwargs = parse_todo_args(args, is_sdxl) + + unet._todo_info = { + "size": None, + "hooks": [], + "args": todo_kwargs, + } + hook_unet(unet) + + for _, module in unet.named_modules(): + if module.__class__.__name__ == "BasicTransformerBlock": + module.attn1._todo_info = unet._todo_info + hook_attention(module.attn1) + + return unet + + +def remove_patch(unet: torch.nn.Module): + if hasattr(unet, "_todo_info"): + for hook in unet._todo_info["hooks"]: + hook.remove() + unet._todo_info["hooks"].clear() + + return unet diff --git a/library/train_util.py b/library/train_util.py index fd46f905..84d3b04b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -74,6 +74,7 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec +from library import token_downsampling import library.deepspeed_utils as deepspeed_utils from library.utils import setup_logging, pil_resize @@ -3483,6 +3484,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=0.1, help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", ) + parser.add_argument( + "--todo_factor", + type=float, + nargs="+", + help="token downsampling (ToDo) factor > 1 (recommend around 2-4). Specify multiple to set factor for each depth", + ) + parser.add_argument( + "--todo_max_depth", + type=int, + choices=[1, 2, 3, 4], + help=( + "apply ToDo to deeper layers (lower quality for slight speed increase). SDXL only accepts 2 and 3. Recommend 1 or 2. 
+        ),
+    )
 
     parser.add_argument(
         "--lowram",
@@ -4736,6 +4751,12 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
             clean_memory_on_device(accelerator.device)
         accelerator.wait_for_everyone()
 
+    # apply token merging patch
+    if args.todo_factor:
+        token_downsampling.apply_patch(unet, args)
+        logger.info(f"enable token downsampling optimization: downsample_factor={args.todo_factor}, max_depth={args.todo_max_depth}")
+
     return text_encoder, vae, unet, load_stable_diffusion_format
diff --git a/train_network.py b/train_network.py
index 6953bb17..5baf6c43 100644
--- a/train_network.py
+++ b/train_network.py
@@ -767,6 +767,10 @@ class NetworkTrainer:
            vae_name = os.path.basename(vae_name)
            metadata["ss_vae_name"] = vae_name
 
+        if args.todo_factor:
+            metadata["ss_todo_factor"] = args.todo_factor
+            metadata["ss_todo_max_depth"] = args.todo_max_depth
+
         metadata = {k: str(v) for k, v in metadata.items()}
 
         # make minimum metadata for filtering
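
Below is a minimal usage sketch of the new library/token_downsampling.py module; it is not part of the diff itself. It uses a hypothetical argparse.Namespace in place of the parsed training arguments, and the commented outputs follow from parse_todo_args and compute_merge as implemented above.

import argparse

import torch

from library import token_downsampling

# hypothetical stand-in for the parsed --todo_factor / --todo_max_depth arguments
args = argparse.Namespace(todo_factor=[2.0], todo_max_depth=None)

# non-SDXL: max_depth defaults to min(len(todo_factor), 4) = 1, so only downsample level 1 is patched
print(token_downsampling.parse_todo_args(args))
# {'downsample_factor': {1: 2.0}, 'max_depth': 1}

# SDXL: depths are offset by one level because its transformer blocks start at downsample 2
args = argparse.Namespace(todo_factor=[2.0, 4.0], todo_max_depth=2)
todo_kwargs = token_downsampling.parse_todo_args(args, is_sdxl=True)
print(todo_kwargs)
# {'downsample_factor': {2: 2.0, 4: 4.0}, 'max_depth': 4}

# compute_merge infers the current depth from the token count relative to the latent size,
# then downsamples the self-attention keys/values by that depth's factor (here 96x96 latents at downsample 2)
todo_info = {"size": (96, 96), "args": todo_kwargs}
hidden_states = torch.randn(1, 48 * 48, 320)  # hypothetical hidden states: 48*48 tokens, 320 channels
merge = token_downsampling.compute_merge(hidden_states, todo_info)
print(merge(hidden_states).shape)  # torch.Size([1, 576, 320]) -> 24*24 tokens after the 2x downsample

# in the trainers, load_target_model applies the patch automatically when --todo_factor is set:
# unet = token_downsampling.apply_patch(unet, args, is_sdxl=True)
# ...and remove_patch(unet) detaches the hooks again if needed.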