add experimental split mode for FLUX

This commit is contained in:
Kohya S
2024-08-13 22:28:39 +09:00
parent 9711c96f96
commit 56d7651f08
4 changed files with 304 additions and 23 deletions

View File

@@ -4,12 +4,22 @@ This repository contains training, generation and utility scripts for Stable Dif
This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training.
__Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__
Aug 13, 2024:
__Experimental__ A network argument `train_blocks` is added to `lora_flux`. This selects the target blocks for LoRA from the FLUX double blocks and single blocks. Specify it like `--network_args "train_blocks=single"`. `all` trains both double blocks and single blocks, `double` trains only double blocks, and `single` trains only single blocks. The default (when omitted) is `all`.
This argument is available even if `--split_mode` is not specified.
__Experimental__ `--split_mode` option is added to `flux_train_network.py`. This splits FLUX into double blocks and single blocks for training. By enabling gradients only for the single blocks part, memory usage is reduced. When this option is specified, you need to specify `"train_blocks=single"` in the network arguments.
This option enables training with 12GB VRAM GPUs, but the training speed is 2-3 times slower than the default.
Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-generate the latents cache file if you have used the option before. Aug 11, 2024: Fix `--apply_t5_attn_mask` option to work. Please remove and re-generate the latents cache file if you have used the option before.
Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI. Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flex` to make it compatible with ComfyUI.
__Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__
We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs. We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below, settings are based on [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit). It will work with 24GB VRAM GPUs.
``` ```
@@ -19,7 +29,13 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t
The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below: The training can be done with 16GB VRAM GPUs with Adafactor optimizer. Please use settings like below:
``` ```
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"` --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"
```
The training can be done with 12GB VRAM GPUs with Adafactor optimizer, `--split_mode` and `train_blocks=single` options. Please use settings like below:
```
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --split_mode --network_args "train_blocks=single"
``` ```
LoRAs for Text Encoders are not tested yet. LoRAs for Text Encoders are not tested yet.

View File

@@ -37,10 +37,16 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
args.network_train_unet_only or not args.cache_text_encoder_outputs args.network_train_unet_only or not args.cache_text_encoder_outputs
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
train_dataset_group.verify_bucket_reso_steps(32) train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
def load_target_model(self, args, weight_dtype, accelerator): def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models # currently offload to cpu for some models
name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # TODO change this to a more robust way
# if we load to cpu, flux.to(fp8) takes a long time
model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
if args.split_mode:
model = self.prepare_split_model(model, weight_dtype, accelerator)
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu") clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu")
clip_l.eval() clip_l.eval()
@@ -49,13 +55,47 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu") t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu")
t5xxl.eval() t5xxl.eval()
name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # TODO change this to a more robust way
# if we load to cpu, flux.to(fp8) takes a long time
model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu")
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
def prepare_split_model(self, model, weight_dtype, accelerator, args=None):
    """Split a full FLUX model into a frozen upper half and a trainable lower half.

    The upper half (``FluxUpper``: input embedders + double blocks) is loaded from
    the full model's state dict, frozen, optionally cast/prepared for fp8, and
    parked on CPU as ``self.flux_upper``. The lower half (``FluxLower``: single
    blocks + final layer) is returned as the model to train. The full ``model``
    is consumed (deleted) to free memory.

    Args:
        model: fully loaded ``flux_models.Flux`` instance; deleted before return.
        weight_dtype: dtype for the trainable lower half.
        accelerator: Accelerate object providing the device and fp8 ``prepare``.
        args: parsed command-line namespace (needs ``fp8_base``). Defaults to the
            module-level ``args`` created in the ``__main__`` block for backward
            compatibility; prefer passing it explicitly.

    Returns:
        A ``flux_models.FluxLower`` instance ready for training.
    """
    from accelerate import init_empty_weights

    if args is None:
        # Backward compatibility: the original implementation silently read the
        # module-level ``args`` global defined in the script's __main__ block.
        args = globals().get("args")

    logger.info("prepare split model")
    # Build both halves without allocating real weights; they are filled from
    # the full model's state dict below (attribute names match, so strict=False
    # plus assign=True moves the existing tensors in place).
    with init_empty_weights():
        flux_upper = flux_models.FluxUpper(model.params)
        flux_lower = flux_models.FluxLower(model.params)
    sd = model.state_dict()

    # lower half (trainable)
    logger.info("load state dict for lower")
    flux_lower.load_state_dict(sd, strict=False, assign=True)
    flux_lower.to(dtype=weight_dtype)

    # upper half (frozen)
    logger.info("load state dict for upper")
    flux_upper.load_state_dict(sd, strict=False, assign=True)

    logger.info("prepare upper model")
    target_dtype = torch.float8_e4m3fn if args.fp8_base else weight_dtype
    flux_upper.to(accelerator.device, dtype=target_dtype)
    flux_upper.eval()
    if args.fp8_base:
        # accelerator.prepare is required to run the module on fp8 weights
        flux_upper = accelerator.prepare(flux_upper)
    # park on CPU; the training step swaps upper/lower onto the device as needed
    flux_upper.to("cpu")

    self.flux_upper = flux_upper
    del model  # the full model is no longer needed
    clean_memory_on_device(accelerator.device)

    logger.info("split model prepared")
    return flux_lower
def get_tokenize_strategy(self, args): def get_tokenize_strategy(self, args):
return strategy_flux.FluxTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) return strategy_flux.FluxTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
@@ -262,6 +302,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# f"model_input: {noisy_model_input.shape}, img_ids: {img_ids.shape}, t5_out: {t5_out.shape}, txt_ids: {txt_ids.shape}, l_pooled: {l_pooled.shape}, timesteps: {timesteps.shape}, guidance_vec: {guidance_vec.shape}" # f"model_input: {noisy_model_input.shape}, img_ids: {img_ids.shape}, t5_out: {t5_out.shape}, txt_ids: {txt_ids.shape}, l_pooled: {l_pooled.shape}, timesteps: {timesteps.shape}, guidance_vec: {guidance_vec.shape}"
# ) # )
if not args.split_mode:
# normal forward
with accelerator.autocast(): with accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet( model_pred = unet(
@@ -273,6 +315,38 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
timesteps=timesteps / 1000, timesteps=timesteps / 1000,
guidance=guidance_vec, guidance=guidance_vec,
) )
else:
# split forward to reduce memory usage
assert network.train_blocks == "single", "train_blocks must be single for split mode"
with accelerator.autocast():
# move flux lower to cpu, and then move flux upper to gpu
unet.to("cpu")
clean_memory_on_device(accelerator.device)
self.flux_upper.to(accelerator.device)
# upper model does not require grad
with torch.no_grad():
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
img=packed_noisy_model_input,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
)
# move flux upper back to cpu, and then move flux lower to gpu
self.flux_upper.to("cpu")
clean_memory_on_device(accelerator.device)
unet.to(accelerator.device)
# lower model requires grad
intermediate_img.requires_grad_(True)
intermediate_txt.requires_grad_(True)
vec.requires_grad_(True)
pe.requires_grad_(True)
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe)
# unpack latents # unpack latents
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
@@ -331,6 +405,12 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true", action="store_true",
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
) )
parser.add_argument(
"--split_mode",
action="store_true",
help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
+ "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
)
# copy from Diffusers # copy from Diffusers
parser.add_argument( parser.add_argument(

View File

@@ -918,3 +918,168 @@ class Flux(nn.Module):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img return img
class FluxUpper(nn.Module):
    """
    Upper half of the split FLUX transformer: input embedders plus all
    double-stream (image/text) blocks.

    Used by the experimental ``--split_mode``: this half is frozen and run under
    ``torch.no_grad()``, while ``FluxLower`` (single blocks + final layer) is
    trained. ``forward`` returns the intermediate activations that
    ``FluxLower.forward`` consumes.
    """

    def __init__(self, params: FluxParams):
        super().__init__()
        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        # NOTE: attribute names must match the full Flux model so that
        # load_state_dict(full_model.state_dict(), strict=False) picks up
        # exactly the weights belonging to this half.
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.gradient_checkpointing = False

    @property
    def device(self):
        # Device of the first parameter (assumes the module lives on one device).
        return next(self.parameters()).device

    @property
    def dtype(self):
        # Dtype of the first parameter (assumes a uniform dtype).
        return next(self.parameters()).dtype

    def enable_gradient_checkpointing(self):
        """Enable gradient checkpointing on the embedders and all double blocks."""
        self.gradient_checkpointing = True

        # NOTE(review): when params.guidance_embed is False (schnell),
        # guidance_in is nn.Identity, which has no
        # enable_gradient_checkpointing — confirm this path is never hit there.
        self.time_in.enable_gradient_checkpointing()
        self.vector_in.enable_gradient_checkpointing()
        self.guidance_in.enable_gradient_checkpointing()

        for block in self.double_blocks:
            block.enable_gradient_checkpointing()

        print("FLUX: Gradient checkpointing enabled.")

    def disable_gradient_checkpointing(self):
        """Disable gradient checkpointing on the embedders and all double blocks."""
        self.gradient_checkpointing = False

        # NOTE(review): same nn.Identity caveat as enable_gradient_checkpointing.
        self.time_in.disable_gradient_checkpointing()
        self.vector_in.disable_gradient_checkpointing()
        self.guidance_in.disable_gradient_checkpointing()

        for block in self.double_blocks:
            block.disable_gradient_checkpointing()

        print("FLUX: Gradient checkpointing disabled.")

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Run embedders and double blocks; return (img, txt, vec, pe) for FluxLower."""
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # positional embedding over concatenated text-then-image token ids
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        return img, txt, vec, pe
class FluxLower(nn.Module):
    """
    Lower half of the split FLUX transformer: all single-stream blocks plus the
    final output layer.

    Used by the experimental ``--split_mode``: this half is the trainable part.
    Its ``forward`` takes the intermediate activations produced by
    ``FluxUpper.forward`` and produces the model prediction.
    """

    def __init__(self, params: FluxParams):
        super().__init__()
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.out_channels = params.in_channels
        # NOTE: attribute names must match the full Flux model so that
        # load_state_dict(full_model.state_dict(), strict=False) picks up
        # exactly the weights belonging to this half.
        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

        self.gradient_checkpointing = False

    @property
    def device(self):
        # Device of the first parameter (assumes the module lives on one device).
        return next(self.parameters()).device

    @property
    def dtype(self):
        # Dtype of the first parameter (assumes a uniform dtype).
        return next(self.parameters()).dtype

    def enable_gradient_checkpointing(self):
        """Enable gradient checkpointing on all single blocks."""
        self.gradient_checkpointing = True

        for block in self.single_blocks:
            block.enable_gradient_checkpointing()

        print("FLUX: Gradient checkpointing enabled.")

    def disable_gradient_checkpointing(self):
        """Disable gradient checkpointing on all single blocks."""
        self.gradient_checkpointing = False

        for block in self.single_blocks:
            block.disable_gradient_checkpointing()

        print("FLUX: Gradient checkpointing disabled.")

    def forward(
        self,
        img: Tensor,
        txt: Tensor,
        vec: Tensor | None = None,
        pe: Tensor | None = None,
    ) -> Tensor:
        """Run single blocks on concatenated (txt, img) tokens and project the image part."""
        # single blocks operate on one concatenated text+image sequence
        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        # drop the text tokens; keep only the image tokens
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

View File

@@ -252,6 +252,11 @@ def create_network(
if module_dropout is not None: if module_dropout is not None:
module_dropout = float(module_dropout) module_dropout = float(module_dropout)
# single or double blocks
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double"
if train_blocks is not None:
assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}"
# すごく引数が多いな ( ^ω^)・・・ # すごく引数が多いな ( ^ω^)・・・
network = LoRANetwork( network = LoRANetwork(
text_encoders, text_encoders,
@@ -264,6 +269,7 @@ def create_network(
module_dropout=module_dropout, module_dropout=module_dropout,
conv_lora_dim=conv_dim, conv_lora_dim=conv_dim,
conv_alpha=conv_alpha, conv_alpha=conv_alpha,
train_blocks=train_blocks,
varbose=True, varbose=True,
) )
@@ -314,7 +320,9 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
class LoRANetwork(torch.nn.Module): class LoRANetwork(torch.nn.Module):
FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"]
FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible
LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
@@ -335,6 +343,7 @@ class LoRANetwork(torch.nn.Module):
module_class: Type[object] = LoRAModule, module_class: Type[object] = LoRAModule,
modules_dim: Optional[Dict[str, int]] = None, modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None, modules_alpha: Optional[Dict[str, int]] = None,
train_blocks: Optional[str] = None,
varbose: Optional[bool] = False, varbose: Optional[bool] = False,
) -> None: ) -> None:
super().__init__() super().__init__()
@@ -347,6 +356,7 @@ class LoRANetwork(torch.nn.Module):
self.dropout = dropout self.dropout = dropout
self.rank_dropout = rank_dropout self.rank_dropout = rank_dropout
self.module_dropout = module_dropout self.module_dropout = module_dropout
self.train_blocks = train_blocks if train_blocks is not None else "all"
self.loraplus_lr_ratio = None self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None self.loraplus_unet_lr_ratio = None
@@ -360,7 +370,9 @@ class LoRANetwork(torch.nn.Module):
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
) )
if self.conv_lora_dim is not None: if self.conv_lora_dim is not None:
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") logger.info(
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
)
# create module instances # create module instances
def create_modules( def create_modules(
@@ -434,9 +446,17 @@ class LoRANetwork(torch.nn.Module):
skipped_te += skipped skipped_te += skipped
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# create LoRA for U-Net
if self.train_blocks == "all":
target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE
elif self.train_blocks == "single":
target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE
elif self.train_blocks == "double":
target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.FLUX_TARGET_REPLACE_MODULE) self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules)
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.")
skipped = skipped_te + skipped_un skipped = skipped_te + skipped_un
if varbose and len(skipped) > 0: if varbose and len(skipped) > 0: