support SD3 LoRA

This commit is contained in:
kohya-ss
2024-10-25 21:58:31 +09:00
parent f52fb66e8f
commit d2c549d7b2
7 changed files with 1334 additions and 67 deletions

View File

@@ -761,6 +761,9 @@ class MMDiT(nn.Module):
 self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels)
 # self.initialize_weights()
+self.blocks_to_swap = None
+self.thread_pool: Optional[ThreadPoolExecutor] = None
+
 @property
 def model_type(self):
     return self._model_type

View File

@@ -198,6 +198,23 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
     help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用",
 )
+parser.add_argument(
+    "--t5xxl_max_token_length",
+    type=int,
+    default=256,
+    help="maximum token length for T5-XXL. 256 is the default value / T5-XXLの最大トークン長。デフォルトは256",
+)
+parser.add_argument(
+    "--apply_lg_attn_mask",
+    action="store_true",
+    help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する",
+)
+parser.add_argument(
+    "--apply_t5_attn_mask",
+    action="store_true",
+    help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する",
+)
+
 # copy from Diffusers
 parser.add_argument(
     "--weighting_scheme",
@@ -317,7 +334,7 @@ def do_sample(
 x = noise_scaled.to(device).to(dtype)
 # print(x.shape)
-with torch.no_grad():
+# with torch.no_grad():
 for i in tqdm(range(len(sigmas) - 1)):
     sigma_hat = sigmas[i]
@@ -378,7 +395,7 @@ def sample_images(
 logger.info("")
 logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
-if not os.path.isfile(args.sample_prompts):
+if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
     logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
     return
@@ -386,7 +403,7 @@ def sample_images(
 # unwrap unet and text_encoder(s)
 mmdit = accelerator.unwrap_model(mmdit)
-text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
+text_encoders = None if text_encoders is None else [accelerator.unwrap_model(te) for te in text_encoders]
 # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])

 prompts = train_util.load_prompts(args.sample_prompts)
@@ -404,7 +421,7 @@ def sample_images(
 if distributed_state.num_processes <= 1:
     # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
-    with torch.no_grad():
+    with torch.no_grad(), accelerator.autocast():
         for prompt_dict in prompts:
             sample_image_inference(
                 accelerator,
@@ -506,29 +523,39 @@ def sample_image_inference(
 tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
 encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()

-if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
-    te_outputs = sample_prompts_te_outputs[prompt]
-else:
-    l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(prompt)
-    te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens])
-
-lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = te_outputs
+def encode_prompt(prpt):
+    text_encoder_conds = []
+    if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
+        text_encoder_conds = sample_prompts_te_outputs[prpt]
+        print(f"Using cached text encoder outputs for prompt: {prpt}")
+    if text_encoders is not None:
+        print(f"Encoding prompt: {prpt}")
+        tokens_and_masks = tokenize_strategy.tokenize(prpt)
+        # strategy has apply_t5_attn_mask option
+        encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
+        # if text_encoder_conds is not cached, use encoded_text_encoder_conds
+        if len(text_encoder_conds) == 0:
+            text_encoder_conds = encoded_text_encoder_conds
+        else:
+            # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
+            for i in range(len(encoded_text_encoder_conds)):
+                if encoded_text_encoder_conds[i] is not None:
+                    text_encoder_conds[i] = encoded_text_encoder_conds[i]
+    return text_encoder_conds
+
+lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(prompt)
 cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)

 # encode negative prompts
-if sample_prompts_te_outputs and negative_prompt in sample_prompts_te_outputs:
-    neg_te_outputs = sample_prompts_te_outputs[negative_prompt]
-else:
-    l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(negative_prompt)
-    neg_te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens])
-
-lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = neg_te_outputs
+lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(negative_prompt)
 neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)

 # sample image
 clean_memory_on_device(accelerator.device)
-with accelerator.autocast():
-    latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, mmdit.dtype, accelerator.device)
+with accelerator.autocast(), torch.no_grad():
+    # mmdit may be fp8, so we need weight_dtype here. vae is always in that dtype.
+    latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, vae.dtype, accelerator.device)

 # latent to image
 clean_memory_on_device(accelerator.device)

View File

@@ -91,7 +91,7 @@ def load_mmdit(
 mmdit = sd3_models.create_sd3_mmdit(params, attn_mode)
 logger.info("Loading state dict...")
-info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype)
+info = mmdit.load_state_dict(mmdit_sd, strict=False, assign=True)
 logger.info(f"Loaded MMDiT: {info}")
 return mmdit

networks/lora_sd3.py (new file, 826 lines)
View File

@@ -0,0 +1,826 @@
# temporary minimum implementation of LoRA
# SD3 doesn't have Conv2d, so we ignore it
# TODO commonize with the original/SD3/FLUX implementation
# LoRA network module
# reference:
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from transformers import CLIPTextModelWithProjection, T5EncoderModel
import numpy as np
import torch
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from networks.lora_flux import LoRAModule, LoRAInfModule
from library import sd3_models
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae: sd3_models.SDVAE,
text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
mmdit,
neuron_dropout: Optional[float] = None,
**kwargs,
):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
network_alpha = 1.0
# extract dim/alpha for conv2d, and block dim
conv_dim = kwargs.get("conv_dim", None)
conv_alpha = kwargs.get("conv_alpha", None)
if conv_dim is not None:
conv_dim = int(conv_dim)
if conv_alpha is None:
conv_alpha = 1.0
else:
conv_alpha = float(conv_alpha)
# attn dim, mlp dim, mod dim: per-module dims for context_block and x_block inside each joint block
context_attn_dim = kwargs.get("context_attn_dim", None)
context_mlp_dim = kwargs.get("context_mlp_dim", None)
context_mod_dim = kwargs.get("context_mod_dim", None)
x_attn_dim = kwargs.get("x_attn_dim", None)
x_mlp_dim = kwargs.get("x_mlp_dim", None)
x_mod_dim = kwargs.get("x_mod_dim", None)
if context_attn_dim is not None:
context_attn_dim = int(context_attn_dim)
if context_mlp_dim is not None:
context_mlp_dim = int(context_mlp_dim)
if context_mod_dim is not None:
context_mod_dim = int(context_mod_dim)
if x_attn_dim is not None:
x_attn_dim = int(x_attn_dim)
if x_mlp_dim is not None:
x_mlp_dim = int(x_mlp_dim)
if x_mod_dim is not None:
x_mod_dim = int(x_mod_dim)
type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim]
if all([d is None for d in type_dims]):
type_dims = None
# emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear]
emb_dims = kwargs.get("emb_dims", None)
if emb_dims is not None:
emb_dims = emb_dims.strip()
if emb_dims.startswith("[") and emb_dims.endswith("]"):
emb_dims = emb_dims[1:-1]
emb_dims = [int(d) for d in emb_dims.split(",")] # is it better to use ast.literal_eval?
assert len(emb_dims) == 6, f"invalid emb_dims: {emb_dims}, must be 6 dimensions (context, t, x, y, final_mod, final_linear)"
# selection of which joint blocks to train
def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
"""
Parse a block selection string and return a list of booleans.
Args:
selection (str): A string specifying which blocks to select.
total_blocks (int): The total number of blocks available.
Returns:
List[bool]: A list of booleans indicating which blocks are selected.
"""
if selection == "all":
return [True] * total_blocks
if selection == "none" or selection == "":
return [False] * total_blocks
selected = [False] * total_blocks
ranges = selection.split(",")
for r in ranges:
if "-" in r:
start, end = map(str.strip, r.split("-"))
start = int(start)
end = int(end)
assert 0 <= start < total_blocks, f"invalid start index: {start}"
assert 0 <= end < total_blocks, f"invalid end index: {end}"
assert start <= end, f"invalid range: {start}-{end}"
for i in range(start, end + 1):
selected[i] = True
else:
index = int(r)
assert 0 <= index < total_blocks, f"invalid index: {index}"
selected[index] = True
return selected
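# Illustrative examples of the parser above (not in the original file), assuming 24 blocks:
#   parse_block_selection("all", 24)    -> [True] * 24
#   parse_block_selection("none", 24)   -> [False] * 24
#   parse_block_selection("0,2-4", 24)  -> True at indices 0, 2, 3 and 4, False elsewhere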
train_block_indices = kwargs.get("train_block_indices", None)
if train_block_indices is not None:
train_block_indices = parse_block_selection(train_block_indices, 999) # 999 is a dummy number
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
if rank_dropout is not None:
rank_dropout = float(rank_dropout)
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
module_dropout = float(module_dropout)
# split qkv
split_qkv = kwargs.get("split_qkv", False)
if split_qkv is not None:
split_qkv = True if split_qkv == "True" else False
# train T5XXL
train_t5xxl = kwargs.get("train_t5xxl", False)
if train_t5xxl is not None:
train_t5xxl = True if train_t5xxl == "True" else False
# verbose
verbose = kwargs.get("verbose", False)
if verbose is not None:
verbose = True if verbose == "True" else False
# that's really a lot of arguments ( ^ω^)...
network = LoRANetwork(
text_encoders,
mmdit,
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
dropout=neuron_dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
split_qkv=split_qkv,
train_t5xxl=train_t5xxl,
type_dims=type_dims,
emb_dims=emb_dims,
train_block_indices=train_block_indices,
verbose=verbose,
)
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
return network
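# Illustrative usage sketch (not in the original file): the kwargs consumed above are
# typically passed as "key=value" strings through the trainer's --network_args option, e.g.
#   --network_module networks.lora_sd3 --network_dim 16 --network_alpha 8 \
#   --network_args "split_qkv=True" "train_t5xxl=True" "train_block_indices=0-11,20"
# Note that the boolean kwargs above are compared against the literal string "True".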
# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, ae, text_encoders, mmdit, weights_sd=None, for_inference=False, **kwargs):
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# get dim/alpha mapping, and train t5xxl
modules_dim = {}
modules_alpha = {}
train_t5xxl = None
for key, value in weights_sd.items():
if "." not in key:
continue
lora_name = key.split(".")[0]
if "alpha" in key:
modules_alpha[lora_name] = value
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# logger.info(lora_name, value.size(), dim)
if train_t5xxl is None or train_t5xxl is False:
train_t5xxl = "lora_te3" in lora_name
if train_t5xxl is None:
train_t5xxl = False
split_qkv = False  # no need to handle split_qkv here, because the state dict keeps qkv combined
module_class = LoRAInfModule if for_inference else LoRAModule
network = LoRANetwork(
text_encoders,
mmdit,
multiplier=multiplier,
modules_dim=modules_dim,
modules_alpha=modules_alpha,
module_class=module_class,
split_qkv=split_qkv,
train_t5xxl=train_t5xxl,
)
return network, weights_sd
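# Illustrative walk-through of the key parsing above (hypothetical key): for
#   "lora_unet_joint_blocks_0_x_block_attn_qkv.lora_down.weight" with shape (rank, in_dim),
# lora_name is the text before the first ".", dim is read from lora_down's first
# dimension, and a "lora_te3" prefix would mark T5XXL as trained.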
class LoRANetwork(torch.nn.Module):
SD3_TARGET_REPLACE_MODULE = ["SingleDiTBlock"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"]
LORA_PREFIX_SD3 = "lora_unet" # make ComfyUI compatible
LORA_PREFIX_TEXT_ENCODER_CLIP_L = "lora_te1"
LORA_PREFIX_TEXT_ENCODER_CLIP_G = "lora_te2"
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible
def __init__(
self,
text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
unet: sd3_models.MMDiT,
multiplier: float = 1.0,
lora_dim: int = 4,
alpha: float = 1,
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
module_class: Type[object] = LoRAModule,
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
split_qkv: bool = False,
train_t5xxl: bool = False,
type_dims: Optional[List[int]] = None,
emb_dims: Optional[List[int]] = None,
train_block_indices: Optional[List[bool]] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
self.multiplier = multiplier
self.lora_dim = lora_dim
self.alpha = alpha
self.conv_lora_dim = conv_lora_dim
self.conv_alpha = conv_alpha
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.split_qkv = split_qkv
self.train_t5xxl = train_t5xxl
self.type_dims = type_dims
self.emb_dims = emb_dims
self.train_block_indices = train_block_indices
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
logger.info(f"create LoRA network from weights")
self.emb_dims = [0] * 6 # create emb_dims
# verbose = True
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
# if self.conv_lora_dim is not None:
# logger.info(
# f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
# )
qkv_dim = 0
if self.split_qkv:
logger.info(f"split qkv for LoRA")
qkv_dim = unet.joint_blocks[0].context_block.attn.qkv.weight.size(0)
if train_t5xxl:
logger.info(f"train T5XXL as well")
# create module instances
def create_modules(
is_mmdit: bool,
text_encoder_idx: Optional[int],
root_module: torch.nn.Module,
target_replace_modules: List[str],
filter: Optional[str] = None,
default_dim: Optional[int] = None,
) -> List[LoRAModule]:
prefix = (
self.LORA_PREFIX_SD3
if is_mmdit
else [self.LORA_PREFIX_TEXT_ENCODER_CLIP_L, self.LORA_PREFIX_TEXT_ENCODER_CLIP_G, self.LORA_PREFIX_TEXT_ENCODER_T5][
text_encoder_idx
]
)
loras = []
skipped = []
for name, module in root_module.named_modules():
if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
if target_replace_modules is None: # dirty hack for all modules
module = root_module # search all modules
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
lora_name = prefix + "." + (name + "." if name else "") + child_name
lora_name = lora_name.replace(".", "_")
if filter is not None and filter not in lora_name:
continue
dim = None
alpha = None
if modules_dim is not None:
# modules are explicitly specified
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
else:
# normally, all modules are targeted
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha
if is_mmdit and type_dims is not None:
# type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim]
identifier = [
("context_block", "attn"),
("context_block", "mlp"),
("context_block", "adaLN_modulation"),
("x_block", "attn"),
("x_block", "mlp"),
("x_block", "adaLN_modulation"),
]
for i, d in enumerate(type_dims):
if d is not None and all([id in lora_name for id in identifier[i]]):
dim = d # may be 0 for skip
break
if is_mmdit and dim and self.train_block_indices is not None and "joint_blocks" in lora_name:
# "lora_unet_joint_blocks_0_x_block_attn_proj..."
block_index = int(lora_name.split("_")[4]) # bit dirty
if self.train_block_indices is not None and not self.train_block_indices[block_index]:
dim = 0
elif self.conv_lora_dim is not None:
dim = self.conv_lora_dim
alpha = self.conv_alpha
if dim is None or dim == 0:
# record the skipped module for reporting
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None):
skipped.append(lora_name)
continue
# qkv split
split_dims = None
if is_mmdit and split_qkv:
if "joint_blocks" in lora_name and "qkv" in lora_name:
split_dims = [qkv_dim // 3] * 3
lora = module_class(
lora_name,
child_module,
self.multiplier,
dim,
alpha,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
split_dims=split_dims,
)
loras.append(lora)
if target_replace_modules is None:
break # all modules are searched
return loras, skipped
# create LoRA for text encoder
# creating every module each time is wasteful; this should be reconsidered
self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
skipped_te = []
for i, text_encoder in enumerate(text_encoders):
index = i
if not train_t5xxl and index >= 2: # 0: CLIP-L, 1: CLIP-G, 2: T5XXL, so we skip T5XXL if train_t5xxl is False
break
logger.info(f"create LoRA for Text Encoder {index+1}:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
# create LoRA for U-Net
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.SD3_TARGET_REPLACE_MODULE)
# emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear]
if self.emb_dims:
for filter, in_dim in zip(
[
"context_embedder",
"t_embedder",
"x_embedder",
"y_embedder",
"final_layer_adaLN_modulation",
"final_layer_linear",
],
self.emb_dims,
):
loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim)
self.unet_loras.extend(loras)
logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.")
if verbose:
for lora in self.unet_loras:
logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")
skipped = skipped_te + skipped_un
if verbose and len(skipped) > 0:
logger.warning(
f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
)
for name in skipped:
logger.info(f"\t{name}")
# assertion
names = set()
for lora in self.text_encoder_loras + self.unet_loras:
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def set_enabled(self, is_enabled):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def load_state_dict(self, state_dict, strict=True):
# override to convert original weight to split qkv
if not self.split_qkv:
return super().load_state_dict(state_dict, strict)
# split qkv
for key in list(state_dict.keys()):
if not ("joint_blocks" in key and "qkv" in key):
continue
weight = state_dict[key]
lora_name = key.split(".")[0]
if "lora_down" in key and "weight" in key:
# dense weight (rank*3, in_dim)
split_weight = torch.chunk(weight, 3, dim=0)
for i, split_w in enumerate(split_weight):
state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w
del state_dict[key]
# print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}")
elif "lora_up" in key and "weight" in key:
# sparse weight (out_dim=sum(split_dims), rank*3)
rank = weight.size(1) // 3
i = 0
split_dim = weight.shape[0] // 3
for j in range(3):
state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dim, j * rank : (j + 1) * rank]
i += split_dim
del state_dict[key]
# alpha is unchanged
return super().load_state_dict(state_dict, strict)
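# Shape sketch for the split above (illustrative numbers): with rank 4 and a combined
# qkv output dimension of 4608 (3 x 1536),
#   lora_down.weight (12, in_dim) -> three (4, in_dim) chunks along dim 0
#   lora_up.weight (4608, 12)     -> three (1536, 4) blocks, block j taken from
#                                    rows j*1536:(j+1)*1536 and columns j*4:(j+1)*4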
def state_dict(self, destination=None, prefix="", keep_vars=False):
if not self.split_qkv:
return super().state_dict(destination, prefix, keep_vars)
# merge qkv
state_dict = super().state_dict(destination, prefix, keep_vars)
new_state_dict = {}
for key in list(state_dict.keys()):
if not ("joint_blocks" in key and "qkv" in key):
new_state_dict[key] = state_dict[key]
continue
if key not in state_dict:
continue # already merged
lora_name = key.split(".")[0]
# (rank, in_dim) * 3
down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(3)]
# (split dim, rank) * 3
up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(3)]
alpha = state_dict.pop(f"{lora_name}.alpha")
# merge down weight
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
# merge up weight (sum of split_dim, rank*3)
qkv_dim, rank = up_weights[0].size()
split_dim = qkv_dim // 3
up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
i = 0
for j in range(3):
up_weight[i : i + split_dim, j * rank : (j + 1) * rank] = up_weights[j]
i += split_dim
new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
new_state_dict[f"{lora_name}.alpha"] = alpha
# print(
# f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
# )
print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha")
return new_state_dict
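# The merged lora_up built above is block-diagonal (illustrative, split_dim=1536, rank=4):
#   [[up_q  0     0   ]
#    [0     up_k  0   ]
#    [0     0     up_v]]  -> shape (4608, 12)
# so the q/k/v LoRA branches stay independent even in the combined checkpoint format.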
def apply_to(self, text_encoders, mmdit, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
lora.apply_to()
self.add_module(lora.lora_name, lora)
# returns whether the weights can be merged into the base model
def is_mergeable(self):
return True
# TODO refactor to common function with apply_to
def merge_to(self, text_encoders, mmdit, weights_sd, dtype=None, device=None):
apply_text_encoder = apply_unet = False
for key in weights_sd.keys():
if (
key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_L)
or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_G)
or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5)
):
apply_text_encoder = True
elif key.startswith(LoRANetwork.LORA_PREFIX_SD3):
apply_unet = True
if apply_text_encoder:
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
logger.info(f"weights are merged")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
# make sure text_encoder_lr is a list of three elements
# if float, use the same value for all three
if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
text_encoder_lr = [default_lr, default_lr, default_lr]
elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr), float(text_encoder_lr)]
elif len(text_encoder_lr) == 1:
text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0], text_encoder_lr[0]]
elif len(text_encoder_lr) == 2:
text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[1], text_encoder_lr[1]]
self.requires_grad_(True)
all_params = []
lr_descriptions = []
def assemble_params(loras, lr, loraplus_ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
for name, param in lora.named_parameters():
if loraplus_ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
params = []
descriptions = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0:
continue
if lr is not None:
if key == "plus":
param_data["lr"] = lr * loraplus_ratio
else:
param_data["lr"] = lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
logger.info("NO LR skipping!")
continue
params.append(param_data)
descriptions.append("plus" if key == "plus" else "")
return params, descriptions
if self.text_encoder_loras:
loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
# split text encoder loras for te1 and te3
te1_loras = [
lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_L)
]
te2_loras = [
lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_G)
]
te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)]
if len(te1_loras) > 0:
logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio)
all_params.extend(params)
lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions])
if len(te2_loras) > 0:
logger.info(f"Text Encoder 2 (CLIP-G): {len(te2_loras)} modules, LR {text_encoder_lr[1]}")
params, descriptions = assemble_params(te2_loras, text_encoder_lr[1], loraplus_lr_ratio)
all_params.extend(params)
lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions])
if len(te3_loras) > 0:
logger.info(f"Text Encoder 3 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[2]}")
params, descriptions = assemble_params(te3_loras, text_encoder_lr[2], loraplus_lr_ratio)
all_params.extend(params)
lr_descriptions.extend(["textencoder 3 " + (" " + d if d else "") for d in descriptions])
if self.unet_loras:
params, descriptions = assemble_params(
self.unet_loras,
unet_lr if unet_lr is not None else default_lr,
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
return all_params, lr_descriptions
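# Illustrative examples of the text_encoder_lr normalization above
# (order is [CLIP-L, CLIP-G, T5XXL]):
#   None or []     -> [default_lr, default_lr, default_lr]
#   1e-4           -> [1e-4, 1e-4, 1e-4]
#   [1e-4]         -> [1e-4, 1e-4, 1e-4]
#   [1e-4, 5e-5]   -> [1e-4, 5e-5, 5e-5]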
def enable_gradient_checkpointing(self):
# not supported
pass
def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)
def on_epoch_start(self, text_encoder, unet):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
from library import train_util
# Precalculate model hashes to save time on indexing
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def backup_weights(self):
# back up the original weights
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not hasattr(org_module, "_lora_org_weight"):
sd = org_module.state_dict()
org_module._lora_org_weight = sd["weight"].detach().clone()
org_module._lora_restored = True
def restore_weights(self):
# restore the original weights
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not org_module._lora_restored:
sd = org_module.state_dict()
sd["weight"] = org_module._lora_org_weight
org_module.load_state_dict(sd)
org_module._lora_restored = True
def pre_calculation(self):
# pre-calculate: merge the LoRA weights into the original modules
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
sd = org_module.state_dict()
org_weight = sd["weight"]
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
sd["weight"] = org_weight + lora_weight
assert sd["weight"].shape == org_weight.shape
org_module.load_state_dict(sd)
org_module._lora_restored = False
lora.enabled = False
def apply_max_norm_regularization(self, max_norm_value, device):
downkeys = []
upkeys = []
alphakeys = []
norms = []
keys_scaled = 0
state_dict = self.state_dict()
for key in state_dict.keys():
if "lora_down" in key and "weight" in key:
downkeys.append(key)
upkeys.append(key.replace("lora_down", "lora_up"))
alphakeys.append(key.replace("lora_down.weight", "alpha"))
for i in range(len(downkeys)):
down = state_dict[downkeys[i]].to(device)
up = state_dict[upkeys[i]].to(device)
alpha = state_dict[alphakeys[i]].to(device)
dim = down.shape[0]
scale = alpha / dim
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
else:
updown = up @ down
updown *= scale
norm = updown.norm().clamp(min=max_norm_value / 2)
desired = torch.clamp(norm, max=max_norm_value)
ratio = desired.cpu() / norm.cpu()
sqrt_ratio = ratio**0.5
if ratio != 1:
keys_scaled += 1
state_dict[upkeys[i]] *= sqrt_ratio
state_dict[downkeys[i]] *= sqrt_ratio
scalednorm = updown.norm() * ratio
norms.append(scalednorm.item())
return keys_scaled, sum(norms) / len(norms), max(norms)
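# Minimal end-to-end sketch (illustrative; assumes mmdit, vae and the three text
# encoders are already loaded, and uses the string booleans expected by create_network):
#   network = create_network(1.0, 16, 8.0, vae, [clip_l, clip_g, t5xxl], mmdit, train_t5xxl="True")
#   network.apply_to([clip_l, clip_g, t5xxl], mmdit, apply_text_encoder=True, apply_unet=True)
#   params, descs = network.prepare_optimizer_params_with_multiple_te_lrs([1e-4, 1e-4, 5e-5], 1e-4, 1e-4)
#   network.save_weights("sd3_lora.safetensors", torch.float16, {})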

View File

@@ -220,12 +220,7 @@ def train(args):
 sd3_state_dict = None

 # load tokenizer and prepare tokenize strategy
-if args.t5xxl_max_token_length is None:
-    t5xxl_max_token_length = 256  # default value for T5XXL
-else:
-    t5xxl_max_token_length = args.t5xxl_max_token_length
-
-sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(t5xxl_max_token_length)
+sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length)
 strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy)

 # load clip_l, clip_g, t5xxl for caching text encoder outputs
@@ -876,6 +871,9 @@ def train(args):
 lg_out = None
 t5_out = None
 lg_pooled = None
+l_attn_mask = None
+g_attn_mask = None
+t5_attn_mask = None

 if lg_out is None:
     # not cached or training, so get from text encoders
@@ -885,7 +883,7 @@ def train(args):
 # text models in sd3_models require "cpu" for input_ids
 input_ids_clip_l = input_ids_clip_l.to("cpu")
 input_ids_clip_g = input_ids_clip_g.to("cpu")
-lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens(
+lg_out, _, lg_pooled, l_attn_mask, g_attn_mask, _ = text_encoding_strategy.encode_tokens(
     sd3_tokenize_strategy,
     [clip_l, clip_g, None],
     [input_ids_clip_l, input_ids_clip_g, None, l_attn_mask, g_attn_mask, None],
@@ -895,7 +893,7 @@ def train(args):
 _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"]
 with torch.no_grad():
     input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None
-    _, t5_out, _ = text_encoding_strategy.encode_tokens(
+    _, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens(
         sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask]
     )
@@ -1104,22 +1102,6 @@ def setup_parser() -> argparse.ArgumentParser:
 parser.add_argument(
     "--use_t5xxl_cache_only", action="store_true", help="cache T5-XXL outputs only / T5-XXLの出力のみキャッシュする"
 )
-parser.add_argument(
-    "--t5xxl_max_token_length",
-    type=int,
-    default=None,
-    help="maximum token length for T5-XXL. 256 if omitted / T5-XXLの最大トークン数。省略時は256",
-)
-parser.add_argument(
-    "--apply_lg_attn_mask",
-    action="store_true",
-    help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する",
-)
-parser.add_argument(
-    "--apply_t5_attn_mask",
-    action="store_true",
-    help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する",
-)
 parser.add_argument(
     "--learning_rate_te1",

sd3_train_network.py (new file, 427 lines)
View File

@@ -0,0 +1,427 @@
import argparse
import copy
import math
import random
from typing import Any, Optional
import torch
from accelerate import Accelerator
from library import strategy_sd3, utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import flux_utils, sd3_models, sd3_train_utils, sd3_utils, strategy_base, train_util
import train_network
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class Sd3NetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
self.is_schnell: Optional[bool] = None
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)
if args.fp8_base_unet:
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for SD3
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
)
args.cache_text_encoder_outputs = True
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# prepare CLIP-L/CLIP-G/T5XXL training flags
self.train_clip = not args.network_train_unet_only
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
if args.max_token_length is not None:
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
loading_dtype = None if args.fp8_base else weight_dtype
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
state_dict = utils.load_safetensors(
args.pretrained_model_name_or_path, "cpu", disable_mmap=args.disable_mmap_load_safetensors, dtype=loading_dtype
)
mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu")
self.model_type = mmdit.model_type
if args.fp8_base:
# check dtype of model
if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {mmdit.dtype}")
elif mmdit.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 SD3 model")
clip_l = sd3_utils.load_clip_l(
args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
)
clip_l.eval()
clip_g = sd3_utils.load_clip_g(
args.clip_g, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
)
clip_g.eval()
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
if args.fp8_base and not args.fp8_base_unet:
loading_dtype = None # as is
else:
loading_dtype = weight_dtype
# loading t5xxl to cpu takes a long time, so we should load to gpu in future
t5xxl = sd3_utils.load_t5xxl(
args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
)
t5xxl.eval()
if args.fp8_base and not args.fp8_base_unet:
# check dtype of model
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
elif t5xxl.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 T5XXL model")
vae = sd3_utils.load_vae(
args.vae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
)
return mmdit.model_type, [clip_l, clip_g, t5xxl], vae, mmdit
def get_tokenize_strategy(self, args):
logger.info(f"t5xxl_max_token_length: {args.t5xxl_max_token_length}")
return strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length, args.tokenizer_cache_dir)
def get_tokenizers(self, tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy):
return [tokenize_strategy.clip_l, tokenize_strategy.clip_g, tokenize_strategy.t5xxl]
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return latents_caching_strategy
def get_text_encoding_strategy(self, args):
return strategy_sd3.Sd3TextEncodingStrategy(args.apply_lg_attn_mask, args.apply_t5_attn_mask)
def post_process_network(self, args, accelerator, network, text_encoders, unet):
# check t5xxl is trained or not
self.train_t5xxl = network.train_t5xxl
if self.train_t5xxl and args.cache_text_encoder_outputs:
raise ValueError(
"T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
)
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
if args.cache_text_encoder_outputs:
if self.train_clip and not self.train_t5xxl:
return text_encoders[0:2] # only CLIP-L/CLIP-G is needed for encoding because T5XXL is cached
else:
return None # no text encoders are needed for encoding because both are cached
else:
return text_encoders # CLIP-L, CLIP-G and T5XXL are needed for encoding
def get_text_encoders_train_flags(self, args, text_encoders):
return [self.train_clip, self.train_clip, self.train_t5xxl]
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
# if the text encoders are trained, we need tokenization, so is_partial is True
return strategy_sd3.Sd3TextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=self.train_clip or self.train_t5xxl,
apply_lg_attn_mask=args.apply_lg_attn_mask,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
else:
return None
def cache_text_encoder_outputs_if_needed(
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
):
if args.cache_text_encoder_outputs:
if not args.lowram:
# メモリ消費を減らす
logger.info("move vae and unet to cpu to save memory")
org_vae_device = vae.device
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
clean_memory_on_device(accelerator.device)
# When the text encoders are not trained, they are not prepared by accelerator, so we need an explicit autocast
logger.info("move text encoders to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
text_encoders[1].to(accelerator.device, dtype=weight_dtype) # always not fp8
text_encoders[2].to(accelerator.device) # may be fp8
if text_encoders[2].dtype == torch.float8_e4m3fn:
# if we load fp8 weights, the model is already fp8, so we use it as is
self.prepare_text_encoder_fp8(2, text_encoders[2], text_encoders[2].dtype, weight_dtype)
else:
# otherwise, we need to convert it to target dtype
text_encoders[2].to(weight_dtype)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
# cache sample prompts
if args.sample_prompts is not None:
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
if p not in sample_prompts_te_outputs:
logger.info(f"cache Text Encoder outputs for prompt: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy,
text_encoders,
tokens_and_masks,
args.apply_lg_attn_mask,
args.apply_t5_attn_mask,
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
accelerator.wait_for_everyone()
# move back to cpu
if not self.is_train_text_encoder(args):
logger.info("move CLIP-L back to cpu")
text_encoders[0].to("cpu")
logger.info("move CLIP-G back to cpu")
text_encoders[1].to("cpu")
logger.info("move t5XXL back to cpu")
text_encoders[2].to("cpu")
clean_memory_on_device(accelerator.device)
if not args.lowram:
logger.info("move vae and unet back to original device")
vae.to(org_vae_device)
unet.to(org_unet_device)
else:
# outputs are fetched from the text encoders every step, so keep them on the GPU
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
text_encoders[2].to(accelerator.device)
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
# # get size embeddings
# orig_size = batch["original_sizes_hw"]
# crop_size = batch["crop_top_lefts"]
# target_size = batch["target_sizes_hw"]
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
# # concat embeddings
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
# return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, mmdit):
text_encoders = text_encoder # for compatibility
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
sd3_train_utils.sample_images(
accelerator, args, epoch, global_step, mmdit, vae, text_encoders, self.sample_prompts_te_outputs
)
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
# shift 3.0 is the default value
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
return noise_scheduler
def encode_images_to_latents(self, args, accelerator, vae, images):
return vae.encode(images)
def shift_scale_latents(self, args, latents):
return latents
def get_noise_pred_and_target(
self,
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet: sd3_models.MMDiT,
network,
weight_dtype,
train_unet,
):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps(
args, self.noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
)
# ensure the hidden state will require grad
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)
for t in text_encoder_conds:
if t is not None and t.dtype.is_floating_point:
t.requires_grad_(True)
# Predict the noise residual
lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled)
if not args.apply_lg_attn_mask:
l_attn_mask = None
g_attn_mask = None
if not args.apply_t5_attn_mask:
t5_attn_mask = None
# call model
with accelerator.autocast():
# TODO support attention mask
model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
# Preconditioning of the model outputs.
model_pred = model_pred * (-sigmas) + noisy_model_input
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
# flow matching loss
target = latents
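# Why target = latents (illustrative derivation; assumes the rectified-flow forward
# process noisy = (1 - sigma) * latents + sigma * noise, with the raw model output
# approximating the velocity v = noise - latents):
#   noisy - sigma * v = (1 - sigma) * latents + sigma * noise - sigma * (noise - latents)
#                     = latents
# so the preconditioned prediction above is compared directly against the clean latents.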
# differential output preservation
if "custom_attributes" in batch:
diff_output_pr_indices = []
for i, custom_attributes in enumerate(batch["custom_attributes"]):
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
diff_output_pr_indices.append(i)
if len(diff_output_pr_indices) > 0:
network.set_multiplier(0.0)
with torch.no_grad(), accelerator.autocast():
model_pred_prior = unet(
noisy_model_input[diff_output_pr_indices],
timesteps[diff_output_pr_indices],
context=context[diff_output_pr_indices],
y=lg_pooled[diff_output_pr_indices],
)
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
model_pred_prior = model_pred_prior * (-sigmas[diff_output_pr_indices]) + noisy_model_input[diff_output_pr_indices]
# weighting for differential output preservation is not needed because it is already applied
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, None, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
def get_sai_model_spec(self, args):
return train_util.get_sai_model_spec(None, args, False, True, False, sd3=self.model_type)
def update_metadata(self, metadata, args):
metadata["ss_apply_lg_attn_mask"] = args.apply_lg_attn_mask
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
metadata["ss_mode_scale"] = args.mode_scale
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
if index == 0 or index == 1: # CLIP-L/CLIP-G
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
else: # T5XXL
text_encoder.encoder.embed_tokens.requires_grad_(True)
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
if index == 0 or index == 1: # CLIP-L/CLIP-G
clip_type = "CLIP-L" if index == 0 else "CLIP-G"
logger.info(f"prepare CLIP-{clip_type} for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
text_encoder.to(te_weight_dtype) # fp8
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
else: # T5XXL
def prepare_fp8(text_encoder, target_dtype):
def forward_hook(module):
def forward(hidden_states):
hidden_gelu = module.act(module.wi_0(hidden_states))
hidden_linear = module.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = module.dropout(hidden_states)
hidden_states = module.wo(hidden_states)
return hidden_states
return forward
for module in text_encoder.modules():
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
# print("set", module.__class__.__name__, "to", target_dtype)
module.to(target_dtype)
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
# print("set", module.__class__.__name__, "hooks")
module.forward = forward_hook(module)
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
logger.info(f"T5XXL already prepared for fp8")
else:
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype)
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
sd3_train_utils.add_sd3_training_arguments(parser)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = Sd3NetworkTrainer()
trainer.train(args)
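# Example invocation (illustrative; paths are placeholders, and flags other than the
# SD3-specific ones added in this commit come from the existing train_network CLI):
#   accelerate launch sd3_train_network.py \
#     --pretrained_model_name_or_path sd3_medium.safetensors \
#     --clip_l clip_l.safetensors --clip_g clip_g.safetensors --t5xxl t5xxl.safetensors \
#     --network_module networks.lora_sd3 --network_dim 16 \
#     --t5xxl_max_token_length 256 --apply_lg_attn_mask --apply_t5_attn_mask \
#     --cache_text_encoder_outputs --output_dir output --output_name sd3-lora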

View File

@@ -129,6 +129,7 @@ class NetworkTrainer:
 def get_models_for_text_encoding(self, args, accelerator, text_encoders):
     """
     Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models.
+    FLUX.1 and SD3 may cache some outputs of the text encoder, so return the models that will be used for encoding (not cached).
     """
     return text_encoders
@@ -591,6 +592,7 @@ class NetworkTrainer:
 # unet.to(accelerator.device)  # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
 # unet.to(dtype=unet_weight_dtype)  # without moving to gpu, this takes a lot of time and main memory
+logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}")
 unet.to(accelerator.device, dtype=unet_weight_dtype)  # this seems to be safer than above
 unet.requires_grad_(False)