mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Merge f506a0eed8 into cadcd3169b
This commit is contained in:
@@ -13,6 +13,7 @@ from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer
|
||||
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
|
||||
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
|
||||
from library import token_downsampling
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -60,6 +61,11 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
clean_memory_on_device(accelerator.device)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# apply token merging patch
|
||||
if args.todo_factor:
|
||||
token_downsampling.apply_patch(unet, args, is_sdxl=True)
|
||||
logger.info(f"enable token downsampling optimization: downsample_factor={args.todo_factor}, max_depth={args.todo_max_depth}")
|
||||
|
||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||
|
||||
|
||||
|
||||
115
library/token_downsampling.py
Normal file
115
library/token_downsampling.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# based on:
|
||||
# https://github.com/ethansmith2000/ImprovedTokenMerge
|
||||
# https://github.com/ethansmith2000/comfy-todo (MIT)
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def up_or_downsample(item, cur_w, cur_h, new_w, new_h, method="nearest-exact"):
    """Resize a flattened token sequence to a new spatial resolution.

    ``item`` is a (batch, cur_h * cur_w, channels) tensor. It is unflattened
    into an NCHW image grid, interpolated to (new_h, new_w) with ``method``,
    and flattened back into a (batch, new_h * new_w, channels) sequence.
    """
    n = item.shape[0]
    # tokens -> NCHW grid so F.interpolate can resize spatially
    grid = item.reshape(n, cur_h, cur_w, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(new_h, new_w), mode=method)
    # NCHW -> NHWC -> flattened token sequence
    return grid.permute(0, 2, 3, 1).reshape(n, new_h * new_w, -1)
|
||||
|
||||
|
||||
def compute_merge(x: torch.Tensor, todo_info: dict):
    """Build the token-merge callable for the current attention layer.

    The UNet depth is inferred from how much the token count of ``x`` shrank
    relative to the original latent size recorded in ``todo_info["size"]``.
    Layers deeper than ``max_depth`` get an identity op; shallower layers get
    a downsampling op with the per-depth factor from ``downsample_factor``.
    """
    original_h, original_w = todo_info["size"]
    original_tokens = original_h * original_w
    # downsample is the power-of-2 spatial reduction at this layer
    downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
    cur_h = original_h // downsample
    cur_w = original_w // downsample

    opts = todo_info["args"]
    if downsample > opts["max_depth"]:
        # too deep for ToDo: leave the tokens untouched
        return lambda tokens: tokens

    factor = opts["downsample_factor"][downsample]
    new_h = int(cur_h / factor)
    new_w = int(cur_w / factor)
    return lambda tokens: up_or_downsample(tokens, cur_w, cur_h, new_w, new_h)
|
||||
|
||||
|
||||
def hook_unet(model: torch.nn.Module):
    """ Adds a forward pre hook to get the image size. This hook can be removed with remove_patch. """
    def _capture_size(module, inputs):
        latent = inputs[0]
        # record (H, W) of the incoming latent so attention hooks know the grid
        module._todo_info["size"] = (latent.shape[2], latent.shape[3])
        return None

    handle = model.register_forward_pre_hook(_capture_size)
    model._todo_info["hooks"].append(handle)
|
||||
|
||||
|
||||
def hook_attention(attn: torch.nn.Module):
    """ Adds a forward pre hook to downsample attention keys and values. This hook can be removed with remove_patch. """
    def _downsample_context(module, args, kwargs):
        hidden_states = args[0]
        merge = compute_merge(hidden_states, module._todo_info)
        # the merged (downsampled) tokens become the k/v source via `context`
        kwargs["context"] = merge(hidden_states)
        return args, kwargs

    handle = attn.register_forward_pre_hook(_downsample_context, with_kwargs=True)
    attn._todo_info["hooks"].append(handle)
|
||||
|
||||
|
||||
def parse_todo_args(args, is_sdxl: bool = False) -> dict:
    """Validate and normalize the ToDo CLI arguments into patch kwargs.

    Args:
        args: parsed argparse namespace with ``todo_factor`` (list of floats)
            and ``todo_max_depth`` (int or None). ``todo_max_depth`` is filled
            in-place with its default when None.
        is_sdxl: SDXL UNets start one downsample level deeper, so depth keys
            and the max depth are offset by one power of 2.

    Returns:
        dict with ``downsample_factor`` (mapping of power-of-2 depth -> factor)
        and ``max_depth`` (largest power-of-2 depth to patch).

    Raises:
        ValueError: if max_depth exceeds 2 for SDXL, or the number of factors
            is neither 1 nor equal to todo_max_depth.
    """
    # validate max_depth: default to one entry per given factor, capped at 4
    if args.todo_max_depth is None:
        args.todo_max_depth = min(len(args.todo_factor), 4)
    if is_sdxl and args.todo_max_depth > 2:
        raise ValueError(f"todo_max_depth for SDXL cannot be larger than 2, received {args.todo_max_depth}")

    # validate factor
    if len(args.todo_factor) > 1:
        if len(args.todo_factor) != args.todo_max_depth:
            raise ValueError(f"todo_factor number of values must be 1 or same as todo_max_depth, received {len(args.todo_factor)}")

    # create dict of factors to support per-depth override
    # BUGFIX: copy before broadcasting — `factors *= n` on the alias mutated
    # args.todo_factor in place, corrupting later logging/metadata of the arg
    factors = list(args.todo_factor)
    if len(factors) == 1:
        factors = factors * args.todo_max_depth
    factors = {2 ** (i + int(is_sdxl)): factor for i, factor in enumerate(factors)}

    # convert depth to powers of 2 to match layer dimensions: [1,2,3,4] -> [1,2,4,8]
    # offset by 1 for sdxl which starts at 2
    max_depth = 2 ** (args.todo_max_depth + int(is_sdxl) - 1)

    todo_kwargs = {
        "downsample_factor": factors,
        "max_depth": max_depth,
    }

    return todo_kwargs
|
||||
|
||||
|
||||
def apply_patch(unet: torch.nn.Module, args, is_sdxl=False):
    """ Patches the UNet's transformer blocks to apply token downsampling. """
    shared_info = {
        "size": None,                           # filled per-forward by the UNet pre-hook
        "hooks": [],                            # handles so remove_patch can undo everything
        "args": parse_todo_args(args, is_sdxl),
    }
    unet._todo_info = shared_info
    hook_unet(unet)

    # every BasicTransformerBlock's self-attention (attn1) shares the same info dict
    for module in unet.modules():
        if type(module).__name__ == "BasicTransformerBlock":
            module.attn1._todo_info = shared_info
            hook_attention(module.attn1)

    return unet
|
||||
|
||||
|
||||
def remove_patch(unet: torch.nn.Module):
    """Detach every hook registered by apply_patch; no-op on an unpatched UNet."""
    if hasattr(unet, "_todo_info"):
        hooks = unet._todo_info["hooks"]
        for handle in hooks:
            handle.remove()
        hooks.clear()

    return unet
|
||||
@@ -74,6 +74,7 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel
|
||||
import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
from library import token_downsampling
|
||||
import library.deepspeed_utils as deepspeed_utils
|
||||
from library.utils import setup_logging, pil_resize
|
||||
|
||||
@@ -3483,6 +3484,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=0.1,
|
||||
help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--todo_factor",
|
||||
type=float,
|
||||
nargs="+",
|
||||
help="token downsampling (ToDo) factor > 1 (recommend around 2-4). Specify multiple to set factor for each depth",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--todo_max_depth",
|
||||
type=int,
|
||||
choices=[1, 2, 3, 4],
|
||||
help=(
|
||||
"apply ToDo to deeper layers (lower quality for slight speed increase). SDXL only accepts 2 and 3. Recommend 1 or 2. Default 1 (or 2 for SDXL)"
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lowram",
|
||||
@@ -4736,6 +4751,12 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# apply token merging patch
|
||||
if args.todo_factor:
|
||||
token_downsampling.apply_patch(unet, args)
|
||||
logger.info(f"enable token downsampling optimization: downsample_factor={args.todo_factor}, max_depth={args.todo_max_depth}")
|
||||
|
||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||
|
||||
|
||||
|
||||
@@ -767,6 +767,10 @@ class NetworkTrainer:
|
||||
vae_name = os.path.basename(vae_name)
|
||||
metadata["ss_vae_name"] = vae_name
|
||||
|
||||
if args.todo_factor:
|
||||
metadata["ss_todo_factor"] = args.todo_factor
|
||||
metadata["ss_todo_max_depth"] = args.todo_max_depth
|
||||
|
||||
metadata = {k: str(v) for k, v in metadata.items()}
|
||||
|
||||
# make minimum metadata for filtering
|
||||
|
||||
Reference in New Issue
Block a user