T5XXL LoRA training, fp8 T5XXL support

This commit is contained in:
Kohya S
2024-09-04 21:33:17 +09:00
parent 6abacf04da
commit b65ae9b439
7 changed files with 222 additions and 67 deletions

View File

@@ -11,6 +11,11 @@ The command to install PyTorch is as follows:
### Recent Updates
Sep 4, 2024:
- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI.
- In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching.
- Fixed an issue where the training CLIP-L LoRA was not used in sample image generation during LoRA training.
Sep 1, 2024:
- `--timestamp_sampling` has `flux_shift` option. Thanks to sdbds!
- This is the same shift as FLUX.1 dev inference, adjusting the timestep sampling depending on the resolution. `--discrete_flow_shift` is ignored when `flux_shift` is specified. It is not verified which is better, `shift` or `flux_shift`.
@@ -41,8 +46,8 @@ Sample command is below. It will work with 24GB VRAM GPUs.
```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py
--pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors
--ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers
--pretrained_model_name_or_path flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors
--ae ae.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers
--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
--network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4
--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base
@@ -72,6 +77,11 @@ The trained LoRA model can be used with ComfyUI.
There are many unknown points in FLUX.1 training, so some settings can be specified by arguments. Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome.
- `--pretrained_model_name_or_path` is the path to the pretrained model (FLUX.1). bf16 (original BFL model) is recommended (`flux1-dev.safetensors` or `flux1-dev.sft`). If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with `float8_e4m3fn` format.
- `--clip_l` is the path to the CLIP-L model.
- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (float8_e4m3fn) models for T5XXL. However, it is recommended to use fp16 models for caching.
- `--ae` is the path to the autoencoder model (`ae.safetensors` or `ae.sft`).
- `--timestep_sampling` is the method to sample timesteps (0-1):
- `sigma`: sigma-based, same as SD3
- `uniform`: uniform random
@@ -114,16 +124,29 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times
#### Key Features for FLUX.1 LoRA training
1. CLIP-L LoRA Support:
- FLUX.1 LoRA training now supports CLIP-L LoRA.
1. CLIP-L and T5XXL LoRA Support:
- FLUX.1 LoRA training now supports CLIP-L and T5XXL LoRA training.
- Remove `--network_train_unet_only` from your command.
- T5XXL is not trained. Its output is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required.
- Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time.
- T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- The trained LoRA can be used with ComfyUI.
- Note: `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA.
- Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet.
| trained LoRA|option|network_args|cache_text_encoder_outputs (*1)|
|---|---|---|---|
|FLUX.1|`--network_train_unet_only`|-|o|
|FLUX.1 + CLIP-L|-|-|o (*2)|
|FLUX.1 + CLIP-L + T5XXL|-|`train_t5xxl=True`|-|
|CLIP-L (*3)|`--network_train_text_encoder_only`|-|o (*2)|
|CLIP-L + T5XXL (*3)|`--network_train_text_encoder_only`|`train_t5xxl=True`|-|
- *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- *2: T5XXL output can be cached for CLIP-L LoRA training.
- *3: Not tested yet.
2. Experimental FP8/FP16 mixed training:
- `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L.
- FLUX can be trained with fp8, and CLIP-L can be trained with bf16/fp16.
- `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L/T5XXL.
- FLUX can be trained with fp8, and CLIP-L/T5XXL can be trained with bf16/fp16.
- When specifying this option, the `--fp8_base` option is automatically enabled.
3. Split Q/K/V Projection Layers (Experimental):
@@ -153,7 +176,7 @@ The compatibility of the saved model (state dict) is ensured by concatenating th
The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.
```
python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.safetensors --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
```
### FLUX.1 fine-tuning
@@ -164,7 +187,7 @@ Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GP
```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py
--pretrained_model_name_or_path flux1-dev.sft --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.sft
--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.safetensors
--save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2
--seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
--dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name
@@ -256,7 +279,7 @@ CLIP-L LoRA is not supported.
`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__
```
python networks/flux_merge_lora.py --flux_model flux1-dev.sft --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu
python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu
```
You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`.

View File

@@ -43,13 +43,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# assert (
# args.network_train_unet_only or not args.cache_text_encoder_outputs
# ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
if not args.network_train_unet_only:
logger.info(
"network for CLIP-L only will be trained. T5XXL will not be trained / CLIP-Lのネットワークのみが学習されます。T5XXLは学習されません"
)
# prepare CLIP-L/T5XXL training flags
self.train_clip_l = not args.network_train_unet_only
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
if args.max_token_length is not None:
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
@@ -63,12 +59,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# currently offload to cpu for some models
name = self.get_flux_model_name(args)
# if we load to cpu, flux.to(fp8) takes a long time
if args.fp8_base:
loading_dtype = None # as is
else:
loading_dtype = weight_dtype
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
loading_dtype = None if args.fp8_base else weight_dtype
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
model = flux_utils.load_flow_model(
name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
)
@@ -85,9 +79,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
clip_l.eval()
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
if args.fp8_base and not args.fp8_base_unet:
loading_dtype = None # as is
else:
loading_dtype = weight_dtype
# loading t5xxl to cpu takes a long time, so we should load to gpu in future
t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
t5xxl.eval()
if args.fp8_base and not args.fp8_base_unet:
# check dtype of model
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
elif t5xxl.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 T5XXL model")
ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
@@ -154,25 +160,35 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def get_text_encoding_strategy(self, args):
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
def post_process_network(self, args, accelerator, network, text_encoders, unet):
# check t5xxl is trained or not
self.train_t5xxl = network.train_t5xxl
if self.train_t5xxl and args.cache_text_encoder_outputs:
raise ValueError(
"T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
)
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
if args.cache_text_encoder_outputs:
if self.is_train_text_encoder(args):
if self.train_clip_l and not self.train_t5xxl:
return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
else:
return text_encoders # ignored
return None # no text encoders are needed for encoding because both are cached
else:
return text_encoders # both CLIP-L and T5XXL are needed for encoding
def get_text_encoders_train_flags(self, args, text_encoders):
return [True, False] if self.is_train_text_encoder(args) else [False, False]
return [self.train_clip_l, self.train_t5xxl]
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
# if the text encoders is trained, we need tokenization, so is_partial is True
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
None,
False,
is_partial=self.is_train_text_encoder(args),
is_partial=self.train_clip_l or self.train_t5xxl,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
else:
@@ -193,8 +209,16 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
logger.info("move text encoders to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
text_encoders[1].to(accelerator.device)
if text_encoders[1].dtype == torch.float8_e4m3fn:
# if we load fp8 weights, the model is already fp8, so we use it as is
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
else:
# otherwise, we need to convert it to target dtype
text_encoders[1].to(weight_dtype)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process)
@@ -235,7 +259,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device)
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -255,9 +279,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
text_encoders = text_encoder # for compatibility
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
if not args.split_mode:
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, flux, ae, text_encoder, self.sample_prompts_te_outputs
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
)
return
@@ -281,7 +308,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
clean_memory_on_device(accelerator.device)
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, wrapper, ae, text_encoder, self.sample_prompts_te_outputs
accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
)
clean_memory_on_device(accelerator.device)
@@ -421,6 +448,47 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
if index == 0: # CLIP-L
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
else: # T5XXL
text_encoder.encoder.embed_tokens.requires_grad_(True)
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
if index == 0: # CLIP-L
logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
text_encoder.to(te_weight_dtype) # fp8
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
else: # T5XXL
def prepare_fp8(text_encoder, target_dtype):
def forward_hook(module):
def forward(hidden_states):
hidden_gelu = module.act(module.wi_0(hidden_states))
hidden_linear = module.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = module.dropout(hidden_states)
hidden_states = module.wo(hidden_states)
return hidden_states
return forward
for module in text_encoder.modules():
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
# print("set", module.__class__.__name__, "to", target_dtype)
module.to(target_dtype)
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
# print("set", module.__class__.__name__, "hooks")
module.forward = forward_hook(module)
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
logger.info(f"T5XXL already prepared for fp8")
else:
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype)
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()

View File

@@ -85,7 +85,7 @@ def sample_images(
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad():
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
sample_image_inference(
accelerator,
@@ -187,14 +187,27 @@ def sample_image_inference(
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_conds = []
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
te_outputs = sample_prompts_te_outputs[prompt]
else:
text_encoder_conds = sample_prompts_te_outputs[prompt]
print(f"Using cached text encoder outputs for prompt: {prompt}")
if text_encoders is not None:
print(f"Encoding prompt: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt)
# strategy has apply_t5_attn_mask option
te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
print([x.shape if x is not None else None for x in encoded_text_encoder_conds])
l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
# sample image
weight_dtype = ae.dtype # TOFO give dtype as argument

View File

@@ -171,7 +171,9 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev
return clip
def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> T5EncoderModel:
def load_t5xxl(
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
) -> T5EncoderModel:
T5_CONFIG_JSON = """
{
"architectures": [
@@ -217,6 +219,11 @@ def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.devi
return t5xxl
def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype:
# nn.Embedding is the first layer, but it could be casted to bfloat16 or float32
return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype
def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]

View File

@@ -5,8 +5,7 @@ import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from library import sd3_utils, train_util
from library import sd3_models
from library import flux_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
@@ -100,6 +99,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_t5_attn_mask = apply_t5_attn_mask
self.warn_fp8_weights = False
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
@@ -144,6 +145,14 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
if not self.warn_fp8_weights:
if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
logger.warning(
"T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs."
" / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
)
self.warn_fp8_weights = True
flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
captions = [info.caption for info in infos]

View File

@@ -330,6 +330,11 @@ def create_network(
if split_qkv is not None:
split_qkv = True if split_qkv == "True" else False
# train T5XXL
train_t5xxl = kwargs.get("train_t5xxl", False)
if train_t5xxl is not None:
train_t5xxl = True if train_t5xxl == "True" else False
# すごく引数が多いな ( ^ω^)・・・
network = LoRANetwork(
text_encoders,
@@ -344,6 +349,7 @@ def create_network(
conv_alpha=conv_alpha,
train_blocks=train_blocks,
split_qkv=split_qkv,
train_t5xxl=train_t5xxl,
varbose=True,
)
@@ -370,9 +376,10 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
else:
weights_sd = torch.load(file, map_location="cpu")
# get dim/alpha mapping
# get dim/alpha mapping, and train t5xxl
modules_dim = {}
modules_alpha = {}
train_t5xxl = None
for key, value in weights_sd.items():
if "." not in key:
continue
@@ -385,6 +392,12 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
modules_dim[lora_name] = dim
# logger.info(lora_name, value.size(), dim)
if train_t5xxl is None:
train_t5xxl = "lora_te3" in lora_name
if train_t5xxl is None:
train_t5xxl = False
# # split qkv
# double_qkv_rank = None
# single_qkv_rank = None
@@ -413,6 +426,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
modules_alpha=modules_alpha,
module_class=module_class,
split_qkv=split_qkv,
train_t5xxl=train_t5xxl,
)
return network, weights_sd
@@ -421,10 +435,10 @@ class LoRANetwork(torch.nn.Module):
# FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"]
FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"]
LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible
LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2"
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible
def __init__(
self,
@@ -443,6 +457,7 @@ class LoRANetwork(torch.nn.Module):
modules_alpha: Optional[Dict[str, int]] = None,
train_blocks: Optional[str] = None,
split_qkv: bool = False,
train_t5xxl: bool = False,
varbose: Optional[bool] = False,
) -> None:
super().__init__()
@@ -457,6 +472,7 @@ class LoRANetwork(torch.nn.Module):
self.module_dropout = module_dropout
self.train_blocks = train_blocks if train_blocks is not None else "all"
self.split_qkv = split_qkv
self.train_t5xxl = train_t5xxl
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
@@ -469,12 +485,16 @@ class LoRANetwork(torch.nn.Module):
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
if self.conv_lora_dim is not None:
logger.info(
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
)
# if self.conv_lora_dim is not None:
# logger.info(
# f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
# )
if self.split_qkv:
logger.info(f"split qkv for LoRA")
if self.train_blocks is not None:
logger.info(f"train {self.train_blocks} blocks only")
if train_t5xxl:
logger.info(f"train T5XXL as well")
# create module instances
def create_modules(
@@ -550,12 +570,15 @@ class LoRANetwork(torch.nn.Module):
skipped_te = []
for i, text_encoder in enumerate(text_encoders):
index = i
if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
break
logger.info(f"create LoRA for Text Encoder {index+1}:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# create LoRA for U-Net
if self.train_blocks == "all":

View File

@@ -157,6 +157,9 @@ class NetworkTrainer:
# region SD/SDXL
def post_process_network(self, args, accelerator, network, text_encoders, unet):
pass
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
@@ -237,6 +240,13 @@ class NetworkTrainer:
def is_text_encoder_not_needed_for_training(self, args):
return False # use for sample images
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
# set top parameter requires_grad = True for gradient checkpointing works
text_encoder.text_model.embeddings.requires_grad_(True)
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
# endregion
def train(self, args):
@@ -329,7 +339,7 @@ class NetworkTrainer:
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
self.assert_extra_args(args, train_dataset_group)
self.assert_extra_args(args, train_dataset_group) # may change some args
# acceleratorを準備する
logger.info("preparing accelerator")
@@ -428,12 +438,15 @@ class NetworkTrainer:
)
args.scale_weight_norms = False
self.post_process_network(args, accelerator, network, text_encoders, unet)
# apply network to unet and text_encoder
train_unet = not args.network_train_text_encoder_only
train_text_encoder = self.is_train_text_encoder(args)
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
if args.network_weights is not None:
# FIXME consider alpha of weights
# FIXME consider alpha of weights: this assumes that the alpha is not changed
info = network.load_weights(args.network_weights)
accelerator.print(f"load network weights from {args.network_weights}: {info}")
@@ -545,17 +558,16 @@ class NetworkTrainer:
unet.requires_grad_(False)
unet.to(dtype=unet_weight_dtype)
for t_enc in text_encoders:
for i, t_enc in enumerate(text_encoders):
t_enc.requires_grad_(False)
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
if t_enc.device.type != "cpu":
t_enc.to(dtype=te_weight_dtype)
if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"):
# nn.Embedding not support FP8
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"):
t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
# nn.Embedding not support FP8
if te_weight_dtype != weight_dtype:
self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if args.deepspeed:
@@ -596,12 +608,12 @@ class NetworkTrainer:
if args.gradient_checkpointing:
# according to TI example in Diffusers, train is required
unet.train()
for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))):
t_enc.train()
# set top parameter requires_grad = True for gradient checkpointing works
if frag:
t_enc.text_model.embeddings.requires_grad_(True)
self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc)
else:
unet.eval()
@@ -1028,8 +1040,12 @@ class NetworkTrainer:
# log device and dtype for each model
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
for t_enc in text_encoders:
logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}")
for i, t_enc in enumerate(text_encoders):
params_itr = t_enc.parameters()
params_itr.__next__() # skip the first parameter
params_itr.__next__() # skip the second parameter. because CLIP first two parameters are embeddings
param_3rd = params_itr.__next__()
logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}")
clean_memory_on_device(accelerator.device)
@@ -1085,11 +1101,7 @@ class NetworkTrainer:
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
if (
len(text_encoder_conds) == 0
or text_encoder_conds[0] is None
or train_text_encoder
):
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions: