This commit is contained in:
Hope I.
2025-06-26 17:31:52 +00:00
committed by GitHub
4 changed files with 146 additions and 0 deletions

View File

@@ -13,6 +13,7 @@ from tqdm import tqdm
from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
from library import token_downsampling
from .utils import setup_logging
setup_logging()
@@ -60,6 +61,11 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
# apply token merging patch
if args.todo_factor:
token_downsampling.apply_patch(unet, args, is_sdxl=True)
logger.info(f"enable token downsampling optimization: downsample_factor={args.todo_factor}, max_depth={args.todo_max_depth}")
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info

View File

@@ -0,0 +1,115 @@
# based on:
# https://github.com/ethansmith2000/ImprovedTokenMerge
# https://github.com/ethansmith2000/comfy-todo (MIT)
import math
import torch
import torch.nn.functional as F
def up_or_downsample(item, cur_w, cur_h, new_w, new_h, method="nearest-exact"):
batch_size = item.shape[0]
item = item.reshape(batch_size, cur_h, cur_w, -1).permute(0, 3, 1, 2)
item = F.interpolate(item, size=(new_h, new_w), mode=method).permute(0, 2, 3, 1)
item = item.reshape(batch_size, new_h * new_w, -1)
return item
def compute_merge(x: torch.Tensor, todo_info: dict):
original_h, original_w = todo_info["size"]
original_tokens = original_h * original_w
downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
cur_h = original_h // downsample
cur_w = original_w // downsample
args = todo_info["args"]
merge_op = lambda x: x
if downsample <= args["max_depth"]:
downsample_factor = args["downsample_factor"][downsample]
new_h = int(cur_h / downsample_factor)
new_w = int(cur_w / downsample_factor)
merge_op = lambda x: up_or_downsample(x, cur_w, cur_h, new_w, new_h)
return merge_op
def hook_unet(model: torch.nn.Module):
""" Adds a forward pre hook to get the image size. This hook can be removed with remove_patch. """
def hook(module, args):
module._todo_info["size"] = (args[0].shape[2], args[0].shape[3])
return None
model._todo_info["hooks"].append(model.register_forward_pre_hook(hook))
def hook_attention(attn: torch.nn.Module):
""" Adds a forward pre hook to downsample attention keys and values. This hook can be removed with remove_patch. """
def hook(module, args, kwargs):
hidden_states = args[0]
m = compute_merge(hidden_states, module._todo_info)
kwargs["context"] = m(hidden_states)
return args, kwargs
attn._todo_info["hooks"].append(attn.register_forward_pre_hook(hook, with_kwargs=True))
def parse_todo_args(args, is_sdxl: bool = False) -> dict:
# validate max_depth
if args.todo_max_depth is None:
args.todo_max_depth = min(len(args.todo_factor), 4)
if is_sdxl and args.todo_max_depth > 2:
raise ValueError(f"todo_max_depth for SDXL cannot be larger than 2, received {args.todo_max_depth}")
# validate factor
if len(args.todo_factor) > 1:
if len(args.todo_factor) != args.todo_max_depth:
raise ValueError(f"todo_factor number of values must be 1 or same as todo_max_depth, received {len(args.todo_factor)}")
# create dict of factors to support per-depth override
factors = args.todo_factor
if len(factors) == 1:
factors *= args.todo_max_depth
factors = {2**(i + int(is_sdxl)): factor for i, factor in enumerate(factors)}
# convert depth to powers of 2 to match layer dimensions: [1,2,3,4] -> [1,2,4,8]
# offset by 1 for sdxl which starts at 2
max_depth = 2**(args.todo_max_depth + int(is_sdxl) - 1)
todo_kwargs = {
"downsample_factor": factors,
"max_depth": max_depth,
}
return todo_kwargs
def apply_patch(unet: torch.nn.Module, args, is_sdxl=False):
""" Patches the UNet's transformer blocks to apply token downsampling. """
todo_kwargs = parse_todo_args(args, is_sdxl)
unet._todo_info = {
"size": None,
"hooks": [],
"args": todo_kwargs,
}
hook_unet(unet)
for _, module in unet.named_modules():
if module.__class__.__name__ == "BasicTransformerBlock":
module.attn1._todo_info = unet._todo_info
hook_attention(module.attn1)
return unet
def remove_patch(unet: torch.nn.Module):
if hasattr(unet, "_todo_info"):
for hook in unet._todo_info["hooks"]:
hook.remove()
unet._todo_info["hooks"].clear()
return unet

View File

@@ -74,6 +74,7 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel
import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
from library import token_downsampling
import library.deepspeed_utils as deepspeed_utils
from library.utils import setup_logging, pil_resize
@@ -3483,6 +3484,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=0.1,
help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
)
parser.add_argument(
"--todo_factor",
type=float,
nargs="+",
help="token downsampling (ToDo) factor > 1 (recommend around 2-4). Specify multiple to set factor for each depth",
)
parser.add_argument(
"--todo_max_depth",
type=int,
choices=[1, 2, 3, 4],
help=(
"apply ToDo to deeper layers (lower quality for slight speed increase). SDXL only accepts 2 and 3. Recommend 1 or 2. Default 1 (or 2 for SDXL)"
),
)
parser.add_argument(
"--lowram",
@@ -4736,6 +4751,12 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
# apply token merging patch
if args.todo_factor:
token_downsampling.apply_patch(unet, args)
logger.info(f"enable token downsampling optimization: downsample_factor={args.todo_factor}, max_depth={args.todo_max_depth}")
return text_encoder, vae, unet, load_stable_diffusion_format

View File

@@ -767,6 +767,10 @@ class NetworkTrainer:
vae_name = os.path.basename(vae_name)
metadata["ss_vae_name"] = vae_name
if args.todo_factor:
metadata["ss_todo_factor"] = args.todo_factor
metadata["ss_todo_max_depth"] = args.todo_max_depth
metadata = {k: str(v) for k, v in metadata.items()}
# make minimum metadata for filtering