Add support for specifying rank for each layer in FLUX.1

This commit is contained in:
Kohya S
2024-09-14 22:17:52 +09:00
parent 2d8ee3c280
commit c9ff4de905
2 changed files with 161 additions and 7 deletions

View File

@@ -11,6 +11,10 @@ The command to install PyTorch is as follows:
### Recent Updates
Sep 14, 2024:
- You can now specify the rank for each layer in FLUX.1. See [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) for details.
- OFT is now supported with FLUX.1. See [FLUX.1 OFT training](#flux1-oft-training) for details.
Sep 11, 2024:
Logging to wandb is improved. See PR [#1576](https://github.com/kohya-ss/sd-scripts/pull/1576) for details. Thanks to p1atdev!
@@ -46,6 +50,7 @@ Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `
- [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training)
- [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model)
- [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training)
- [FLUX.1 OFT training](#flux1-oft-training)
- [FLUX.1 fine-tuning](#flux1-fine-tuning)
- [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning)
- [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models)
@@ -191,6 +196,62 @@ In the implementation of Black Forest Labs' model, the projection layers of q/k/
The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large.
#### Specify rank for each layer in FLUX.1
You can specify the rank for each layer in FLUX.1 by specifying the following network_args. If you specify `0`, LoRA will not be applied to that layer.
When network_args is not specified, the default value (`network_dim`) is applied, same as before.
|network_args|target layer|
|---|---|
|img_attn_dim|img_attn in DoubleStreamBlock|
|txt_attn_dim|txt_attn in DoubleStreamBlock|
|img_mlp_dim|img_mlp in DoubleStreamBlock|
|txt_mlp_dim|txt_mlp in DoubleStreamBlock|
|img_mod_dim|img_mod in DoubleStreamBlock|
|txt_mod_dim|txt_mod in DoubleStreamBlock|
|single_dim|linear1 and linear2 in SingleStreamBlock|
|single_mod_dim|modulation in SingleStreamBlock|
example:
```
--network_args "img_attn_dim=4" "img_mlp_dim=8" "txt_attn_dim=2" "txt_mlp_dim=2"
"img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2"
```
You can apply LoRA to the conditioning layers of Flux by specifying `in_dims` in network_args. When specifying, be sure to specify 5 numbers in `[]` as a comma-separated list.
example:
```
--network_args "in_dims=[4,2,2,2,4]"
```
Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, `txt_in`. The above example applies LoRA to all conditioning layers, with rank 4 for `img_in`, 2 for `time_in`, `vector_in`, `guidance_in`, and 4 for `txt_in`.
If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,0,4]` applies LoRA only to `img_in` and `txt_in`.
### FLUX.1 OFT training
You can train OFT with almost the same options as LoRA, such as `--timestamp_sampling`. The following points are different.
- Change `--network_module` from `networks.lora_flux` to `networks.oft_flux`.
- `--network_dim` is the number of OFT blocks. Unlike LoRA rank, the smaller the dim, the larger the model. We recommend about 64 or 128. Please make the output dimension of the target layer of OFT divisible by the value of `--network_dim` (an error will occur if it is not divisible). Valid values are 64, 128, 256, 512, 1024, etc.
- `--network_alpha` is treated as a constraint for OFT. We recommend about 1e-2 to 1e-4. The default value when omitted is 1, which is too large, so be sure to specify it.
- CLIP/T5XXL is not supported. Specify `--network_train_unet_only`.
- `--network_args` specifies the hyperparameters of OFT. The following are valid:
- Specify `enable_all_linear=True` to target all linear connections in the MLP layer. The default is False, which targets only attention.
Currently, there is no environment to infer FLUX.1 OFT. Inference is only possible with `flux_minimal_inference.py` (specify OFT model with `--lora`).
Sample command is below. It will work with 24GB VRAM GPUs with the batch size of 1.
```
--network_module networks.oft_flux --network_dim 128 --network_alpha 1e-3
--network_args "enable_all_linear=True" --learning_rate 1e-5
```
The training can be done with 16GB VRAM GPUs without `--enable_all_linear` option and with Adafactor optimizer.
### Inference for FLUX.1 with LoRA model
The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.