mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add experimental split mode for FLUX
This commit is contained in:
22
README.md
22
README.md
@@ -4,12 +4,22 @@ This repository contains training, generation and utility scripts for Stable Dif
|
||||
|
||||
This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training.
|
||||
|
||||
__Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__
|
||||
|
||||
Aug 13, 2024:
|
||||
|
||||
__Experimental__ A network argument `train_blocks` is added to `lora_flux`. This is to select the target blocks of LoRA from FLUX double blocks and single blocks. Specify like `--network_args "train_blocks=single"`. `all` trains both double blocks and single blocks, `double` trains only double blocks, and `single` trains only single blocks. The default (omission) is `all`.
|
||||
|
||||
This argument is available even if `--split_mode` is not specified.
|
||||
|
||||
__Experimental__ `--split_mode` option is added to `flux_train_network.py`. This splits FLUX into double blocks and single blocks for training. By enabling gradients only for the single blocks part, memory usage is reduced. When this option is specified, you need to specify `"train_blocks=single"` in the network arguments.
|
||||
|
||||
This option enables training with 12GB VRAM GPUs, but the training speed is 2-3 times slower than the default.
|
||||
|
||||
Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-generate the latents cache file if you have used the option before.
|
||||
|
||||
Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI.
|
||||
|
||||
__Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__
|
||||
|
||||
We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs.
|
||||
|
||||
```
|
||||
@@ -19,7 +29,13 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t
|
||||
The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below:
|
||||
|
||||
```
|
||||
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"`
|
||||
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"
|
||||
```
|
||||
|
||||
The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_mode` and `train_blocks=single` options. Please use settings like below:
|
||||
|
||||
```
|
||||
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single"
|
||||
```
|
||||
|
||||
LoRAs for Text Encoders are not tested yet.
|
||||
|
||||
@@ -37,10 +37,16 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
args.network_train_unet_only or not args.cache_text_encoder_outputs
|
||||
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
# currently offload to cpu for some models
|
||||
name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # TODO change this to a more robust way
|
||||
# if we load to cpu, flux.to(fp8) takes a long time
|
||||
model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
|
||||
|
||||
if args.split_mode:
|
||||
model = self.prepare_split_model(model, weight_dtype, accelerator)
|
||||
|
||||
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu")
|
||||
clip_l.eval()
|
||||
@@ -49,13 +55,47 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu")
|
||||
t5xxl.eval()
|
||||
|
||||
name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # TODO change this to a more robust way
|
||||
# if we load to cpu, flux.to(fp8) takes a long time
|
||||
model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
|
||||
ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu")
|
||||
|
||||
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
|
||||
|
||||
def prepare_split_model(self, model, weight_dtype, accelerator):
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
logger.info("prepare split model")
|
||||
with init_empty_weights():
|
||||
flux_upper = flux_models.FluxUpper(model.params)
|
||||
flux_lower = flux_models.FluxLower(model.params)
|
||||
sd = model.state_dict()
|
||||
|
||||
# lower (trainable)
|
||||
logger.info("load state dict for lower")
|
||||
flux_lower.load_state_dict(sd, strict=False, assign=True)
|
||||
flux_lower.to(dtype=weight_dtype)
|
||||
|
||||
# upper (frozen)
|
||||
logger.info("load state dict for upper")
|
||||
flux_upper.load_state_dict(sd, strict=False, assign=True)
|
||||
|
||||
logger.info("prepare upper model")
|
||||
target_dtype = torch.float8_e4m3fn if args.fp8_base else weight_dtype
|
||||
flux_upper.to(accelerator.device, dtype=target_dtype)
|
||||
flux_upper.eval()
|
||||
|
||||
if args.fp8_base:
|
||||
# this is required to run on fp8
|
||||
flux_upper = accelerator.prepare(flux_upper)
|
||||
|
||||
flux_upper.to("cpu")
|
||||
|
||||
self.flux_upper = flux_upper
|
||||
del model # we don't need model anymore
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
logger.info("split model prepared")
|
||||
|
||||
return flux_lower
|
||||
|
||||
def get_tokenize_strategy(self, args):
|
||||
return strategy_flux.FluxTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
|
||||
|
||||
@@ -262,6 +302,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
# f"model_input: {noisy_model_input.shape}, img_ids: {img_ids.shape}, t5_out: {t5_out.shape}, txt_ids: {txt_ids.shape}, l_pooled: {l_pooled.shape}, timesteps: {timesteps.shape}, guidance_vec: {guidance_vec.shape}"
|
||||
# )
|
||||
|
||||
if not args.split_mode:
|
||||
# normal forward
|
||||
with accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
model_pred = unet(
|
||||
@@ -273,6 +315,38 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
timesteps=timesteps / 1000,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
else:
|
||||
# split forward to reduce memory usage
|
||||
assert network.train_blocks == "single", "train_blocks must be single for split mode"
|
||||
with accelerator.autocast():
|
||||
# move flux lower to cpu, and then move flux upper to gpu
|
||||
unet.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
self.flux_upper.to(accelerator.device)
|
||||
|
||||
# upper model does not require grad
|
||||
with torch.no_grad():
|
||||
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
|
||||
img=packed_noisy_model_input,
|
||||
img_ids=img_ids,
|
||||
txt=t5_out,
|
||||
txt_ids=txt_ids,
|
||||
y=l_pooled,
|
||||
timesteps=timesteps / 1000,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
|
||||
# move flux upper back to cpu, and then move flux lower to gpu
|
||||
self.flux_upper.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
|
||||
# lower model requires grad
|
||||
intermediate_img.requires_grad_(True)
|
||||
intermediate_txt.requires_grad_(True)
|
||||
vec.requires_grad_(True)
|
||||
pe.requires_grad_(True)
|
||||
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe)
|
||||
|
||||
# unpack latents
|
||||
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
||||
@@ -331,6 +405,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split_mode",
|
||||
action="store_true",
|
||||
help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
|
||||
+ "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
|
||||
)
|
||||
|
||||
# copy from Diffusers
|
||||
parser.add_argument(
|
||||
|
||||
@@ -918,3 +918,168 @@ class Flux(nn.Module):
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
|
||||
class FluxUpper(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, params: FluxParams):
|
||||
super().__init__()
|
||||
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
||||
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
||||
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = True
|
||||
|
||||
self.time_in.enable_gradient_checkpointing()
|
||||
self.vector_in.enable_gradient_checkpointing()
|
||||
self.guidance_in.enable_gradient_checkpointing()
|
||||
|
||||
for block in self.double_blocks:
|
||||
block.enable_gradient_checkpointing()
|
||||
|
||||
print("FLUX: Gradient checkpointing enabled.")
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.time_in.disable_gradient_checkpointing()
|
||||
self.vector_in.disable_gradient_checkpointing()
|
||||
self.guidance_in.disable_gradient_checkpointing()
|
||||
|
||||
for block in self.double_blocks:
|
||||
block.disable_gradient_checkpointing()
|
||||
|
||||
print("FLUX: Gradient checkpointing disabled.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||
if self.params.guidance_embed:
|
||||
if guidance is None:
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
||||
vec = vec + self.vector_in(y)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
return img, txt, vec, pe
|
||||
|
||||
|
||||
class FluxLower(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, params: FluxParams):
|
||||
super().__init__()
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.out_channels = params.in_channels
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = True
|
||||
|
||||
for block in self.single_blocks:
|
||||
block.enable_gradient_checkpointing()
|
||||
|
||||
print("FLUX: Gradient checkpointing enabled.")
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
for block in self.single_blocks:
|
||||
block.disable_gradient_checkpointing()
|
||||
|
||||
print("FLUX: Gradient checkpointing disabled.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
txt: Tensor,
|
||||
vec: Tensor | None = None,
|
||||
pe: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
@@ -252,6 +252,11 @@ def create_network(
|
||||
if module_dropout is not None:
|
||||
module_dropout = float(module_dropout)
|
||||
|
||||
# single or double blocks
|
||||
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double"
|
||||
if train_blocks is not None:
|
||||
assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}"
|
||||
|
||||
# すごく引数が多いな ( ^ω^)・・・
|
||||
network = LoRANetwork(
|
||||
text_encoders,
|
||||
@@ -264,6 +269,7 @@ def create_network(
|
||||
module_dropout=module_dropout,
|
||||
conv_lora_dim=conv_dim,
|
||||
conv_alpha=conv_alpha,
|
||||
train_blocks=train_blocks,
|
||||
varbose=True,
|
||||
)
|
||||
|
||||
@@ -314,7 +320,9 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
|
||||
|
||||
|
||||
class LoRANetwork(torch.nn.Module):
|
||||
FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
|
||||
# FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
|
||||
FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"]
|
||||
FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible
|
||||
LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
|
||||
@@ -335,6 +343,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
module_class: Type[object] = LoRAModule,
|
||||
modules_dim: Optional[Dict[str, int]] = None,
|
||||
modules_alpha: Optional[Dict[str, int]] = None,
|
||||
train_blocks: Optional[str] = None,
|
||||
varbose: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -347,6 +356,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
self.train_blocks = train_blocks if train_blocks is not None else "all"
|
||||
|
||||
self.loraplus_lr_ratio = None
|
||||
self.loraplus_unet_lr_ratio = None
|
||||
@@ -360,7 +370,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
||||
)
|
||||
if self.conv_lora_dim is not None:
|
||||
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
||||
logger.info(
|
||||
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
|
||||
)
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
@@ -434,9 +446,17 @@ class LoRANetwork(torch.nn.Module):
|
||||
skipped_te += skipped
|
||||
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
|
||||
# create LoRA for U-Net
|
||||
if self.train_blocks == "all":
|
||||
target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE
|
||||
elif self.train_blocks == "single":
|
||||
target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE
|
||||
elif self.train_blocks == "double":
|
||||
target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE
|
||||
|
||||
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.FLUX_TARGET_REPLACE_MODULE)
|
||||
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules)
|
||||
logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.")
|
||||
|
||||
skipped = skipped_te + skipped_un
|
||||
if varbose and len(skipped) > 0:
|
||||
|
||||
Reference in New Issue
Block a user