T5XXL LoRA training, fp8 T5XXL support

This commit is contained in:
Kohya S
2024-09-04 21:33:17 +09:00
parent 6abacf04da
commit b65ae9b439
7 changed files with 222 additions and 67 deletions

View File

@@ -11,6 +11,11 @@ The command to install PyTorch is as follows:
### Recent Updates
Sep 4, 2024:
- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI.
- In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching.
- Fixed an issue where the training CLIP-L LoRA was not used in sample image generation during LoRA training.
Sep 1, 2024:
- `--timestamp_sampling` has `flux_shift` option. Thanks to sdbds!
- This is the same shift as FLUX.1 dev inference, adjusting the timestep sampling depending on the resolution. `--discrete_flow_shift` is ignored when `flux_shift` is specified. It is not verified which is better, `shift` or `flux_shift`.
@@ -41,8 +46,8 @@ Sample command is below. It will work with 24GB VRAM GPUs.
```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py
--pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors
--ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers
--pretrained_model_name_or_path flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors
--ae ae.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers
--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
--network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4
--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base
@@ -72,6 +77,11 @@ The trained LoRA model can be used with ComfyUI.
There are many unknown points in FLUX.1 training, so some settings can be specified by arguments. Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome.
- `--pretrained_model_name_or_path` is the path to the pretrained model (FLUX.1). bf16 (original BFL model) is recommended (`flux1-dev.safetensors` or `flux1-dev.sft`). If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with `float8_e4m3fn` format.
- `--clip_l` is the path to the CLIP-L model.
- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (float8_e4m3fn) models for T5XXL. However, it is recommended to use fp16 models for caching.
- `--ae` is the path to the autoencoder model (`ae.safetensors` or `ae.sft`).
- `--timestep_sampling` is the method to sample timesteps (0-1):
- `sigma`: sigma-based, same as SD3
- `uniform`: uniform random
@@ -114,16 +124,29 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times
#### Key Features for FLUX.1 LoRA training
1. CLIP-L LoRA Support:
- FLUX.1 LoRA training now supports CLIP-L LoRA.
1. CLIP-L and T5XXL LoRA Support:
- FLUX.1 LoRA training now supports CLIP-L and T5XXL LoRA training.
- Remove `--network_train_unet_only` from your command.
- T5XXL is not trained. Its output is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required.
- Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time.
- T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- The trained LoRA can be used with ComfyUI.
- Note: `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA.
- Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet.
| trained LoRA|option|network_args|cache_text_encoder_outputs (*1)|
|---|---|---|---|
|FLUX.1|`--network_train_unet_only`|-|o|
|FLUX.1 + CLIP-L|-|-|o (*2)|
|FLUX.1 + CLIP-L + T5XXL|-|`train_t5xxl=True`|-|
|CLIP-L (*3)|`--network_train_text_encoder_only`|-|o (*2)|
|CLIP-L + T5XXL (*3)|`--network_train_text_encoder_only`|`train_t5xxl=True`|-|
- *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- *2: T5XXL output can be cached for CLIP-L LoRA training.
- *3: Not tested yet.
2. Experimental FP8/FP16 mixed training:
- `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L.
- FLUX can be trained with fp8, and CLIP-L can be trained with bf16/fp16.
- `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L/T5XXL.
- FLUX can be trained with fp8, and CLIP-L/T5XXL can be trained with bf16/fp16.
- When specifying this option, the `--fp8_base` option is automatically enabled.
3. Split Q/K/V Projection Layers (Experimental):
@@ -153,7 +176,7 @@ The compatibility of the saved model (state dict) is ensured by concatenating th
The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.
```
python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.safetensors --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
```
### FLUX.1 fine-tuning
@@ -164,7 +187,7 @@ Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GP
```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py
--pretrained_model_name_or_path flux1-dev.sft --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.sft
--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.safetensors
--save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2
--seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
--dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name
@@ -256,7 +279,7 @@ CLIP-L LoRA is not supported.
`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__
```
python networks/flux_merge_lora.py --flux_model flux1-dev.sft --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu
python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu
```
You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`.