Merge branch 'sd3' of https://github.com/kohya-ss/sd-scripts into flux_shift

This commit is contained in:
sdbds
2024-09-05 13:56:13 +08:00
11 changed files with 503 additions and 138 deletions

View File

@@ -11,6 +11,18 @@ The command to install PyTorch is as follows:
### Recent Updates
Sep 5, 2024:
The LoRA merge script now supports CLIP-L and T5XXL LoRA. Please specify `--clip_l` and `--t5xxl`. `--clip_l_save_to` and `--t5xxl_save_to` specify the save destination for CLIP-L and T5XXL. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details.
Sep 4, 2024:
- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI. See [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) for details.
- In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching.
- Fixed an issue where the CLIP-L LoRA being trained was not used in sample image generation during LoRA training.
Sep 1, 2024:
- `--timestep_sampling` now has a `flux_shift` option. Thanks to sdbds!
- This is the same shift as FLUX.1 dev inference, adjusting the timestep sampling depending on the resolution. `--discrete_flow_shift` is ignored when `flux_shift` is specified. It has not been verified whether `shift` or `flux_shift` is better.
Aug 29, 2024:
Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `requirements.txt` is updated.
@@ -37,8 +49,8 @@ Sample command is below. It will work with 24GB VRAM GPUs.
```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py
--pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors
--ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers
--pretrained_model_name_or_path flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors
--ae ae.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers
--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
--network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4
--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base
@@ -68,11 +80,17 @@ The trained LoRA model can be used with ComfyUI.
There are many unknown points in FLUX.1 training, so some settings can be specified by arguments. The arguments are described below. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome.
- `--pretrained_model_name_or_path` is the path to the pretrained model (FLUX.1). bf16 (original BFL model) is recommended (`flux1-dev.safetensors` or `flux1-dev.sft`). If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with `float8_e4m3fn` format.
- `--clip_l` is the path to the CLIP-L model.
- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (float8_e4m3fn) models for T5XXL. However, it is recommended to use fp16 models for caching.
- `--ae` is the path to the autoencoder model (`ae.safetensors` or `ae.sft`).
- `--timestep_sampling` is the method to sample timesteps (0-1):
- `sigma`: sigma-based, same as SD3
- `uniform`: uniform random
- `sigmoid`: sigmoid of a random normal value, same as x-flux, AI-toolkit, etc.
- `shift`: shifts the sigmoid of a normal-distribution random number
- `flux_shift`: shifts the sigmoid of a normal-distribution random number depending on the resolution (same as FLUX.1 dev inference). `--discrete_flow_shift` is ignored when `flux_shift` is specified.
- `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when `--timestep_sampling` is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform.
- This option is also effective when `--timestep_sampling shift` is specified.
- Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution.
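The `shift`-style samplers above can be sketched in a few lines. This is a hedged illustration, not the repository's implementation: `shift_sigmoid` mirrors the transform described (sigmoid of a normal sample, shifted toward 1), and `resolution_to_shift` is an assumed resolution-to-shift mapping modeled on FLUX.1 dev inference, with illustrative constants.

```python
import math

def shift_sigmoid(x: float, shift: float, scale: float = 1.0) -> float:
    # x is a sample from N(0, 1); take its (scaled) sigmoid, then shift it
    # toward 1 the same way `--timestep_sampling shift` is described above.
    t = 1.0 / (1.0 + math.exp(-scale * x))
    return shift * t / (1.0 + (shift - 1.0) * t)

def resolution_to_shift(height: int, width: int) -> float:
    # Assumed flux_shift-style mapping: packed sequence length -> mu -> exp(mu).
    # The endpoints (256 -> 0.5, 4096 -> 1.15) are illustrative values.
    seq_len = (height // 16) * (width // 16)
    mu = 0.5 + (1.15 - 0.5) * (seq_len - 256) / (4096 - 256)
    return math.exp(mu)
```

With `shift == 1.0` the transform is the identity; larger shifts push sampled timesteps toward 1, and `flux_shift` simply derives the shift from the image resolution instead of `--discrete_flow_shift`.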
@@ -109,16 +127,29 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times
#### Key Features for FLUX.1 LoRA training
1. CLIP-L LoRA Support:
- FLUX.1 LoRA training now supports CLIP-L LoRA.
1. CLIP-L and T5XXL LoRA Support:
- FLUX.1 LoRA training now supports CLIP-L and T5XXL LoRA training.
- Remove `--network_train_unet_only` from your command.
- T5XXL is not trained. Its output is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required.
- Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time.
- T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- The trained LoRA can be used with ComfyUI.
- Note: `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA.
- Note: `flux_extract_lora.py`, `convert_flux_lora.py`, and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet.
| trained LoRA|option|network_args|cache_text_encoder_outputs (*1)|
|---|---|---|---|
|FLUX.1|`--network_train_unet_only`|-|o|
|FLUX.1 + CLIP-L|-|-|o (*2)|
|FLUX.1 + CLIP-L + T5XXL|-|`train_t5xxl=True`|-|
|CLIP-L (*3)|`--network_train_text_encoder_only`|-|o (*2)|
|CLIP-L + T5XXL (*3)|`--network_train_text_encoder_only`|`train_t5xxl=True`|-|
- *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- *2: T5XXL output can be cached for CLIP-L LoRA training.
- *3: Not tested yet.
2. Experimental FP8/FP16 mixed training:
- `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L.
- FLUX can be trained with fp8, and CLIP-L can be trained with bf16/fp16.
- `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L/T5XXL.
- FLUX can be trained with fp8, and CLIP-L/T5XXL can be trained with bf16/fp16.
- When specifying this option, the `--fp8_base` option is automatically enabled.
3. Split Q/K/V Projection Layers (Experimental):
@@ -148,7 +179,7 @@ The compatibility of the saved model (state dict) is ensured by concatenating th
The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.
```
python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.safetensors --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
```
### FLUX.1 fine-tuning
@@ -159,7 +190,7 @@ Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GP
```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py
--pretrained_model_name_or_path flux1-dev.sft --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.sft
--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.safetensors
--save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2
--seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
--dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name
@@ -179,7 +210,7 @@ Options are almost the same as LoRA training. The difference is `--full_bf16`, `
`--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now.
`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks.
`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks. Please see the next chapter for details.
`--cpu_offload_checkpointing` offloads gradient checkpointing to the CPU. This reduces VRAM usage by about 2GB.
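The idea behind `--blockwise_fused_optimizers` can be illustrated with a torch-free sketch; the function names here are hypothetical, not the script's API. Instead of keeping gradients for every block alive until a single optimizer step, each block is stepped and its gradients freed as soon as its part of the backward pass completes, so at most one block's gradients are resident at a time.

```python
def blockwise_fused_backward(blocks, compute_grads, apply_update):
    # Step each block's optimizer during the backward sweep; gradients are
    # freed immediately after the update, keeping peak gradient memory low.
    live_grads = {}
    peak = 0
    for name in reversed(blocks):  # backward pass visits blocks in reverse
        live_grads[name] = compute_grads(name)
        peak = max(peak, len(live_grads))
        apply_update(name, live_grads.pop(name))  # optimizer step + free
    return peak

def naive_backward(blocks, compute_grads, apply_update):
    # Conventional training: all gradients are alive before the single step.
    live_grads = {name: compute_grads(name) for name in reversed(blocks)}
    peak = len(live_grads)
    for name, grad in live_grads.items():
        apply_update(name, grad)
    return peak
```

The fused variant trades nothing in the computed updates; it only changes when each block's step happens, which is why stochastic-rounding support requires extra care.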
@@ -193,24 +224,32 @@ The learning rate and the number of epochs are not optimized yet. Please adjust
#### Key Features for FLUX.1 fine-tuning
1. Sample Image Generation:
1. Technical details of double/single block swap:
- Reduce memory usage by transferring double and single blocks of FLUX.1 from GPU to CPU when they are not needed.
- During the forward pass, the weights of blocks that have finished their calculations are transferred to the CPU, and the weights of blocks about to be calculated are transferred to the GPU.
- The same is true for the backward pass, but the order is reversed. The gradients remain on the GPU.
- Since the transfer between CPU and GPU takes time, the training will be slower.
- `--double_blocks_to_swap` and `--single_blocks_to_swap` specify the number of blocks to swap. For example, `--double_blocks_to_swap 6` swaps 6 blocks at each step of training, but the remaining 13 blocks are always on the GPU.
- About 640MB of memory can be saved per double block, and about 320MB of memory can be saved per single block.
2. Sample Image Generation:
- Sample image generation during training is now supported.
- The prompts are cached and used for generation if `--cache_latents` is specified, so changing the prompts during training will not affect the generated images.
- Specify options such as `--sample_prompts` and `--sample_every_n_epochs`.
- Note: It will be very slow when `--split_mode` is specified.
2. Experimental Memory-Efficient Saving:
3. Experimental Memory-Efficient Saving:
- `--mem_eff_save` option can further reduce memory consumption during model saving (about 22GB).
- This is a custom implementation and may cause unexpected issues. Use with caution.
3. T5XXL Token Length Control:
4. T5XXL Token Length Control:
- Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL.
- The default is 512 for the dev model and 256 for the schnell model.
4. Multi-GPU Training Support:
5. Multi-GPU Training Support:
- Note: `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training.
5. Disable mmap Load for Safetensors:
6. Disable mmap Load for Safetensors:
- `--disable_mmap_load_safetensors` option now works in `flux_train.py`.
- Speeds up model loading during training in WSL2.
- Effective in reducing memory usage when loading models during multi-GPU training.
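The per-block figures quoted for double/single block swap (~640MB per double block, ~320MB per single block) make the VRAM savings easy to estimate. This helper is illustrative, not part of the scripts:

```python
def estimated_swap_savings_mb(double_blocks: int = 0, single_blocks: int = 0) -> int:
    """Rough VRAM saved by --double_blocks_to_swap / --single_blocks_to_swap,
    using ~640MB per double block and ~320MB per single block (figures above)."""
    return double_blocks * 640 + single_blocks * 320
```

For example, `--double_blocks_to_swap 6` saves roughly 3840MB, and the recommended maximums (9 double, 18 single) save about 11520MB, at the cost of slower steps due to CPU/GPU transfers.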
@@ -240,21 +279,32 @@ CLIP-L LoRA is not supported.
### Merge LoRA to FLUX.1 checkpoint
`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__
`networks/flux_merge_lora.py` merges LoRA into a FLUX.1 checkpoint, CLIP-L, or T5XXL model. __The script is experimental.__
```
python networks/flux_merge_lora.py --flux_model flux1-dev.sft --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu
python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu
```
You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`.
`--loading_device` is the device to load the LoRA models. `--working_device` is the device to merge (calculate) the models. Default is `cpu` for both. Loading / working device examples are below (in the case of `--save_precision fp16` or `--save_precision bf16`):
CLIP-L and T5XXL LoRA are supported. `--clip_l` and `--clip_l_save_to` are for CLIP-L, and `--t5xxl` and `--t5xxl_save_to` are for T5XXL. A sample set of options is below.
```
--clip_l clip_l.safetensors --clip_l_save_to merged_clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --t5xxl_save_to merged_t5xxl.safetensors
```
FLUX.1, CLIP-L, and T5XXL can be merged together or separately for memory efficiency.
An experimental option `--mem_eff_load_save` is available for memory-efficient loading and saving. It may also speed up loading and saving.
`--loading_device` is the device used to load the LoRA models, and `--working_device` is the device used to merge (calculate) the models. The default is `cpu` for both. Example loading / working device combinations are below (in the case of `--save_precision fp16` or `--save_precision bf16`; `float32` will consume more memory):
- 'cpu' / 'cpu': Uses >50GB of RAM, but works on any machine.
- 'cuda' / 'cpu': Uses 24GB of VRAM, but requires 30GB of RAM.
- 'cuda' / 'cuda': Uses 30GB of VRAM, but requires 30GB of RAM, faster than 'cuda' / 'cpu'.
- 'cpu' / 'cuda': Uses 4GB of VRAM, but requires 50GB of RAM, faster than 'cpu' / 'cpu' or 'cuda' / 'cpu'.
- 'cuda' / 'cuda': Uses 30GB of VRAM, but requires 30GB of RAM, faster than 'cpu' / 'cpu' or 'cuda' / 'cpu'.
When LoRA models are trained with `bf16`, we are not sure whether `fp16` or `bf16` is better for `--save_precision`.
`--save_precision` is the precision used to save the merged model. When LoRA models are trained with `bf16`, we are not sure whether `fp16` or `bf16` is better for `--save_precision`.
The script can merge multiple LoRA models. If you want to merge multiple LoRA models, specify the `--concat` option so that the merged LoRA model works properly.
@@ -309,6 +359,9 @@ resolution = [512, 512]
SD3 training is done with `sd3_train.py`.
__Sep 1, 2024__:
- `--num_last_block_to_freeze` is added to `sd3_train.py`. This option is to freeze the last n blocks of the MMDiT. See [#1417](https://github.com/kohya-ss/sd-scripts/pull/1417) for details. Thanks to sdbds!
__Jul 27, 2024__:
- The latents and text encoder outputs caching mechanism has been refactored significantly.
- Existing cache files for SD3 need to be recreated. Please delete the previous cache files.

View File

@@ -5,7 +5,7 @@ import datetime
import math
import os
import random
from typing import Callable, List, Optional, Tuple
from typing import Callable, List, Optional
import einops
import numpy as np
@@ -13,6 +13,7 @@ import torch
from tqdm import tqdm
from PIL import Image
import accelerate
from transformers import CLIPTextModel
from library import device_utils
from library.device_utils import init_ipex, get_preferred_device
@@ -125,7 +126,7 @@ def do_sample(
def generate_image(
model,
clip_l,
clip_l: CLIPTextModel,
t5xxl,
ae,
prompt: str,
@@ -141,12 +142,13 @@ def generate_image(
# make first noise with packed shape
# original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
noise_dtype = torch.float32 if is_fp8(dtype) else dtype
noise = torch.randn(
1,
packed_latent_height * packed_latent_width,
16 * 2 * 2,
device=device,
dtype=dtype,
dtype=noise_dtype,
generator=torch.Generator(device=device).manual_seed(seed),
)
@@ -166,9 +168,48 @@ def generate_image(
clip_l = clip_l.to(device)
t5xxl = t5xxl.to(device)
with torch.no_grad():
        if is_fp8(clip_l_dtype) or is_fp8(t5xxl_dtype):
            clip_l.to(clip_l_dtype)
            t5xxl.to(t5xxl_dtype)
        if is_fp8(clip_l_dtype):
            param_itr = clip_l.parameters()
            param_itr.__next__()  # skip first
            param_2nd = param_itr.__next__()
            if param_2nd.dtype != clip_l_dtype:
                logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
                clip_l.to(clip_l_dtype)  # fp8
                clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
        with accelerator.autocast():
            l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
        if is_fp8(t5xxl_dtype):
            if flux_utils.get_t5xxl_actual_dtype(t5xxl) != t5xxl_dtype or not hasattr(t5xxl, "fp8_prepared"):
                logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")

                def prepare_fp8(text_encoder, target_dtype):
                    def forward_hook(module):
                        def forward(hidden_states):
                            hidden_gelu = module.act(module.wi_0(hidden_states))
                            hidden_linear = module.wi_1(hidden_states)
                            hidden_states = hidden_gelu * hidden_linear
                            hidden_states = module.dropout(hidden_states)
                            hidden_states = module.wo(hidden_states)
                            return hidden_states

                        return forward

                    for module in text_encoder.modules():
                        if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
                            # print("set", module.__class__.__name__, "to", target_dtype)
                            module.to(target_dtype)
                        if module.__class__.__name__ in ["T5DenseGatedActDense"]:
                            # print("set", module.__class__.__name__, "hooks")
                            module.forward = forward_hook(module)

                    text_encoder.fp8_prepared = True

                t5xxl.to(t5xxl_dtype)
                prepare_fp8(t5xxl.encoder, torch.bfloat16)
with accelerator.autocast():
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
@@ -315,10 +356,10 @@ if __name__ == "__main__":
t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
t5xxl.eval()
if is_fp8(clip_l_dtype):
clip_l = accelerator.prepare(clip_l)
if is_fp8(t5xxl_dtype):
t5xxl = accelerator.prepare(t5xxl)
# if is_fp8(clip_l_dtype):
# clip_l = accelerator.prepare(clip_l)
# if is_fp8(t5xxl_dtype):
# t5xxl = accelerator.prepare(t5xxl)
t5xxl_max_length = 256 if is_schnell else 512
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
@@ -329,14 +370,16 @@ if __name__ == "__main__":
model.eval()
logger.info(f"Casting model to {flux_dtype}")
model.to(flux_dtype) # make sure model is dtype
if is_fp8(flux_dtype):
model = accelerator.prepare(model)
# if is_fp8(flux_dtype):
# model = accelerator.prepare(model)
# if args.offload:
# model = model.to("cpu")
# AE
ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device)
ae.eval()
if is_fp8(ae_dtype):
ae = accelerator.prepare(ae)
# if is_fp8(ae_dtype):
# ae = accelerator.prepare(ae)
# LoRA
lora_models: List[lora_flux.LoRANetwork] = []
@@ -360,7 +403,7 @@ if __name__ == "__main__":
lora_model.to(device)
lora_models.append(lora_model)
if not args.interactive:
generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance)
else:

View File

@@ -651,7 +651,7 @@ def train(args):
else:
with torch.no_grad():
# encode images to latents. images are [-1, 1]
latents = ae.encode(batch["images"])
latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):

View File

@@ -43,13 +43,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# assert (
# args.network_train_unet_only or not args.cache_text_encoder_outputs
# ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
if not args.network_train_unet_only:
logger.info(
"network for CLIP-L only will be trained. T5XXL will not be trained / CLIP-Lのネットワークのみが学習されます。T5XXLは学習されません"
)
# prepare CLIP-L/T5XXL training flags
self.train_clip_l = not args.network_train_unet_only
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
if args.max_token_length is not None:
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
@@ -63,12 +59,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# currently offload to cpu for some models
name = self.get_flux_model_name(args)
# if we load to cpu, flux.to(fp8) takes a long time
        if args.fp8_base:
            loading_dtype = None  # as is
        else:
            loading_dtype = weight_dtype
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
loading_dtype = None if args.fp8_base else weight_dtype
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
model = flux_utils.load_flow_model(
name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
)
@@ -85,9 +79,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
clip_l.eval()
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
        if args.fp8_base and not args.fp8_base_unet:
            loading_dtype = None  # as is
        else:
            loading_dtype = weight_dtype
# loading t5xxl to cpu takes a long time, so we should load to gpu in future
t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
t5xxl.eval()
        if args.fp8_base and not args.fp8_base_unet:
            # check dtype of model
            if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
                raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
            elif t5xxl.dtype == torch.float8_e4m3fn:
                logger.info("Loaded fp8 T5XXL model")
ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
@@ -154,25 +160,35 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def get_text_encoding_strategy(self, args):
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
    def post_process_network(self, args, accelerator, network, text_encoders, unet):
        # check t5xxl is trained or not
        self.train_t5xxl = network.train_t5xxl

        if self.train_t5xxl and args.cache_text_encoder_outputs:
            raise ValueError(
                "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
            )
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
if args.cache_text_encoder_outputs:
if self.is_train_text_encoder(args):
if self.train_clip_l and not self.train_t5xxl:
return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
else:
return text_encoders # ignored
return None # no text encoders are needed for encoding because both are cached
else:
return text_encoders # both CLIP-L and T5XXL are needed for encoding
def get_text_encoders_train_flags(self, args, text_encoders):
return [True, False] if self.is_train_text_encoder(args) else [False, False]
return [self.train_clip_l, self.train_t5xxl]
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
# if the text encoders is trained, we need tokenization, so is_partial is True
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
None,
False,
is_partial=self.is_train_text_encoder(args),
is_partial=self.train_clip_l or self.train_t5xxl,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
else:
@@ -193,8 +209,16 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
logger.info("move text encoders to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
text_encoders[1].to(accelerator.device)
            if text_encoders[1].dtype == torch.float8_e4m3fn:
                # if we load fp8 weights, the model is already fp8, so we use it as is
                self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
            else:
                # otherwise, we need to convert it to target dtype
                text_encoders[1].to(weight_dtype)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process)
@@ -235,7 +259,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device)
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -255,9 +279,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
text_encoders = text_encoder # for compatibility
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
if not args.split_mode:
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, flux, ae, text_encoder, self.sample_prompts_te_outputs
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
)
return
@@ -281,7 +308,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
clean_memory_on_device(accelerator.device)
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, wrapper, ae, text_encoder, self.sample_prompts_te_outputs
accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
)
clean_memory_on_device(accelerator.device)
@@ -421,6 +448,47 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
    def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
        if index == 0:  # CLIP-L
            return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
        else:  # T5XXL
            text_encoder.encoder.embed_tokens.requires_grad_(True)

    def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
        if index == 0:  # CLIP-L
            logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
            text_encoder.to(te_weight_dtype)  # fp8
            text_encoder.text_model.embeddings.to(dtype=weight_dtype)
        else:  # T5XXL

            def prepare_fp8(text_encoder, target_dtype):
                def forward_hook(module):
                    def forward(hidden_states):
                        hidden_gelu = module.act(module.wi_0(hidden_states))
                        hidden_linear = module.wi_1(hidden_states)
                        hidden_states = hidden_gelu * hidden_linear
                        hidden_states = module.dropout(hidden_states)
                        hidden_states = module.wo(hidden_states)
                        return hidden_states

                    return forward

                for module in text_encoder.modules():
                    if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
                        # print("set", module.__class__.__name__, "to", target_dtype)
                        module.to(target_dtype)
                    if module.__class__.__name__ in ["T5DenseGatedActDense"]:
                        # print("set", module.__class__.__name__, "hooks")
                        module.forward = forward_hook(module)

            if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
                logger.info("T5XXL already prepared for fp8")
            else:
                logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
                text_encoder.to(te_weight_dtype)  # fp8
                prepare_fp8(text_encoder, weight_dtype)
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()

View File

@@ -85,7 +85,7 @@ def sample_images(
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad():
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
sample_image_inference(
accelerator,
@@ -187,14 +187,27 @@ def sample_image_inference(
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_conds = []
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
te_outputs = sample_prompts_te_outputs[prompt]
else:
text_encoder_conds = sample_prompts_te_outputs[prompt]
print(f"Using cached text encoder outputs for prompt: {prompt}")
if text_encoders is not None:
print(f"Encoding prompt: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt)
# strategy has apply_t5_attn_mask option
te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
print([x.shape if x is not None else None for x in encoded_text_encoder_conds])
l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
# sample image
weight_dtype = ae.dtype  # TODO give dtype as argument
@@ -586,8 +599,8 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
"--timestep_sampling",
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
default="sigma",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid."
" / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト。",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
" / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト",
)
parser.add_argument(
"--sigmoid_scale",

View File

@@ -171,7 +171,9 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev
return clip
def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> T5EncoderModel:
def load_t5xxl(
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
) -> T5EncoderModel:
T5_CONFIG_JSON = """
{
"architectures": [
@@ -217,6 +219,11 @@ def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.devi
return t5xxl
def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype:
# nn.Embedding is the first layer, but it could be casted to bfloat16 or float32
return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype
def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]


@@ -5,8 +5,7 @@ import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from library import sd3_utils, train_util
from library import sd3_models
from library import flux_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
@@ -100,6 +99,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_t5_attn_mask = apply_t5_attn_mask
self.warn_fp8_weights = False
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
@@ -144,6 +145,14 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
if not self.warn_fp8_weights:
if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
logger.warning(
"T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs."
" / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
)
self.warn_fp8_weights = True
flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
captions = [info.caption for info in infos]


@@ -2,6 +2,7 @@ import argparse
import math
import os
import time
from typing import Any, Dict, Union
import torch
from safetensors import safe_open
@@ -34,11 +35,11 @@ def load_state_dict(file_name, dtype):
return sd, metadata
def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False):
def save_to_file(file_name, state_dict: Dict[str, Union[Any, torch.Tensor]], dtype, metadata, mem_eff_save=False):
if dtype is not None:
logger.info(f"converting to {dtype}...")
for key in tqdm(list(state_dict.keys())):
if type(state_dict[key]) == torch.Tensor:
if type(state_dict[key]) == torch.Tensor and state_dict[key].dtype.is_floating_point:
state_dict[key] = state_dict[key].to(dtype)
logger.info(f"saving to: {file_name}")
@@ -49,26 +50,76 @@ def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False):
def merge_to_flux_model(
loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False
loading_device,
working_device,
flux_path: str,
clip_l_path: str,
t5xxl_path: str,
models,
ratios,
merge_dtype,
save_dtype,
mem_eff_load_save=False,
):
# create module map without loading state_dict
logger.info(f"loading keys from FLUX.1 model: {flux_model}")
lora_name_to_module_key = {}
with safe_open(flux_model, framework="pt", device=loading_device) as flux_file:
keys = list(flux_file.keys())
for key in keys:
if key.endswith(".weight"):
module_name = ".".join(key.split(".")[:-1])
lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_")
lora_name_to_module_key[lora_name] = key
if flux_path is not None:
logger.info(f"loading keys from FLUX.1 model: {flux_path}")
with safe_open(flux_path, framework="pt", device=loading_device) as flux_file:
keys = list(flux_file.keys())
for key in keys:
if key.endswith(".weight"):
module_name = ".".join(key.split(".")[:-1])
lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_")
lora_name_to_module_key[lora_name] = key
lora_name_to_clip_l_key = {}
if clip_l_path is not None:
logger.info(f"loading keys from clip_l model: {clip_l_path}")
with safe_open(clip_l_path, framework="pt", device=loading_device) as clip_l_file:
keys = list(clip_l_file.keys())
for key in keys:
if key.endswith(".weight"):
module_name = ".".join(key.split(".")[:-1])
lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP + "_" + module_name.replace(".", "_")
lora_name_to_clip_l_key[lora_name] = key
lora_name_to_t5xxl_key = {}
if t5xxl_path is not None:
logger.info(f"loading keys from t5xxl model: {t5xxl_path}")
with safe_open(t5xxl_path, framework="pt", device=loading_device) as t5xxl_file:
keys = list(t5xxl_file.keys())
for key in keys:
if key.endswith(".weight"):
module_name = ".".join(key.split(".")[:-1])
lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5 + "_" + module_name.replace(".", "_")
lora_name_to_t5xxl_key[lora_name] = key
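The three key-mapping loops above share one naming convention: drop the `.weight` suffix, join the module path with underscores, and prepend the network prefix. A minimal sketch of that mapping (the helper name is hypothetical; the prefixes come from `lora_flux.LoRANetwork`):

```python
def module_key_to_lora_name(prefix: str, weight_key: str) -> str:
    # e.g. "double_blocks.0.img_attn.qkv.weight" -> "lora_unet_double_blocks_0_img_attn_qkv"
    module_name = ".".join(weight_key.split(".")[:-1])
    return prefix + "_" + module_name.replace(".", "_")
```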
flux_state_dict = {}
clip_l_state_dict = {}
t5xxl_state_dict = {}
if mem_eff_load_save:
flux_state_dict = {}
with MemoryEfficientSafeOpen(flux_model) as flux_file:
for key in tqdm(flux_file.keys()):
flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed
if flux_path is not None:
with MemoryEfficientSafeOpen(flux_path) as flux_file:
for key in tqdm(flux_file.keys()):
flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed
if clip_l_path is not None:
with MemoryEfficientSafeOpen(clip_l_path) as clip_l_file:
for key in tqdm(clip_l_file.keys()):
clip_l_state_dict[key] = clip_l_file.get_tensor(key).to(loading_device)
if t5xxl_path is not None:
with MemoryEfficientSafeOpen(t5xxl_path) as t5xxl_file:
for key in tqdm(t5xxl_file.keys()):
t5xxl_state_dict[key] = t5xxl_file.get_tensor(key).to(loading_device)
else:
flux_state_dict = load_file(flux_model, device=loading_device)
if flux_path is not None:
flux_state_dict = load_file(flux_path, device=loading_device)
if clip_l_path is not None:
clip_l_state_dict = load_file(clip_l_path, device=loading_device)
if t5xxl_path is not None:
t5xxl_state_dict = load_file(t5xxl_path, device=loading_device)
for model, ratio in zip(models, ratios):
logger.info(f"loading: {model}")
@@ -81,8 +132,20 @@ def merge_to_flux_model(
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
if lora_name not in lora_name_to_module_key:
logger.warning(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.")
if lora_name in lora_name_to_module_key:
module_weight_key = lora_name_to_module_key[lora_name]
state_dict = flux_state_dict
elif lora_name in lora_name_to_clip_l_key:
module_weight_key = lora_name_to_clip_l_key[lora_name]
state_dict = clip_l_state_dict
elif lora_name in lora_name_to_t5xxl_key:
module_weight_key = lora_name_to_t5xxl_key[lora_name]
state_dict = t5xxl_state_dict
else:
logger.warning(
f"no module found for LoRA weight: {key}. Skipping..."
f"LoRAの重みに対応するモジュールが見つかりませんでした。スキップします。"
)
continue
down_weight = lora_sd.pop(key)
@@ -93,11 +156,7 @@ def merge_to_flux_model(
scale = alpha / dim
# W <- W + U * D
module_weight_key = lora_name_to_module_key[lora_name]
if module_weight_key not in flux_state_dict:
weight = flux_file.get_tensor(module_weight_key)
else:
weight = flux_state_dict[module_weight_key]
weight = state_dict[module_weight_key]
weight = weight.to(working_device, merge_dtype)
up_weight = up_weight.to(working_device, merge_dtype)
@@ -121,7 +180,7 @@ def merge_to_flux_model(
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype)
state_dict[module_weight_key] = weight.to(loading_device, save_dtype)
del up_weight
del down_weight
del weight
@@ -129,7 +188,7 @@ def merge_to_flux_model(
if len(lora_sd) > 0:
logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}")
return flux_state_dict
return flux_state_dict, clip_l_state_dict, t5xxl_state_dict
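For a linear layer, the per-module update in the loop above (`W <- W + U * D`, with `scale = alpha / dim`) reduces to the following sketch; the conv-kernel and qkv-split handling of the real script are omitted:

```python
import torch

def merge_lora_delta(weight: torch.Tensor, down_weight: torch.Tensor,
                     up_weight: torch.Tensor, alpha: float, ratio: float) -> torch.Tensor:
    # scale = alpha / rank, where rank is the LoRA inner dimension
    rank = down_weight.shape[0]
    scale = alpha / rank
    # W <- W + ratio * (U @ D) * scale
    return weight + ratio * (up_weight @ down_weight) * scale
```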
def merge_to_flux_model_diffusers(
@@ -508,17 +567,28 @@ def merge(args):
if save_dtype is None:
save_dtype = merge_dtype
dest_dir = os.path.dirname(args.save_to)
assert (
args.save_to or args.clip_l_save_to or args.t5xxl_save_to
), "save_to or clip_l_save_to or t5xxl_save_to must be specified / save_toまたはclip_l_save_toまたはt5xxl_save_toを指定してください"
dest_dir = os.path.dirname(args.save_to or args.clip_l_save_to or args.t5xxl_save_to)
if not os.path.exists(dest_dir):
logger.info(f"creating directory: {dest_dir}")
os.makedirs(dest_dir)
if args.flux_model is not None:
if args.flux_model is not None or args.clip_l is not None or args.t5xxl is not None:
if not args.diffusers:
state_dict = merge_to_flux_model(
assert (args.clip_l is None and args.clip_l_save_to is None) or (
args.clip_l is not None and args.clip_l_save_to is not None
), "clip_l_save_to must be specified if clip_l is specified / clip_lが指定されている場合はclip_l_save_toも指定してください"
assert (args.t5xxl is None and args.t5xxl_save_to is None) or (
args.t5xxl is not None and args.t5xxl_save_to is not None
), "t5xxl_save_to must be specified if t5xxl is specified / t5xxlが指定されている場合はt5xxl_save_toも指定してください"
flux_state_dict, clip_l_state_dict, t5xxl_state_dict = merge_to_flux_model(
args.loading_device,
args.working_device,
args.flux_model,
args.clip_l,
args.t5xxl,
args.models,
args.ratios,
merge_dtype,
@@ -526,7 +596,10 @@ def merge(args):
args.mem_eff_load_save,
)
else:
state_dict = merge_to_flux_model_diffusers(
assert (
args.clip_l is None and args.t5xxl is None
), "clip_l and t5xxl are not supported with --diffusers / clip_l、t5xxlはDiffusersではサポートされていません"
flux_state_dict = merge_to_flux_model_diffusers(
args.loading_device,
args.working_device,
args.flux_model,
@@ -536,8 +609,10 @@ def merge(args):
save_dtype,
args.mem_eff_load_save,
)
clip_l_state_dict = None
t5xxl_state_dict = None
if args.no_metadata:
if args.no_metadata or (flux_state_dict is None or len(flux_state_dict) == 0):
sai_metadata = None
else:
merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models)
@@ -546,15 +621,24 @@ def merge(args):
None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev"
)
logger.info(f"saving FLUX model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype, sai_metadata, args.mem_eff_load_save)
if flux_state_dict is not None and len(flux_state_dict) > 0:
logger.info(f"saving FLUX model to: {args.save_to}")
save_to_file(args.save_to, flux_state_dict, save_dtype, sai_metadata, args.mem_eff_load_save)
if clip_l_state_dict is not None and len(clip_l_state_dict) > 0:
logger.info(f"saving clip_l model to: {args.clip_l_save_to}")
save_to_file(args.clip_l_save_to, clip_l_state_dict, save_dtype, None, args.mem_eff_load_save)
if t5xxl_state_dict is not None and len(t5xxl_state_dict) > 0:
logger.info(f"saving t5xxl model to: {args.t5xxl_save_to}")
save_to_file(args.t5xxl_save_to, t5xxl_state_dict, save_dtype, None, args.mem_eff_load_save)
else:
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
flux_state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
logger.info("calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(flux_state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
@@ -562,12 +646,12 @@ def merge(args):
merged_from = sai_model_spec.build_merged_from(args.models)
title = os.path.splitext(os.path.basename(args.save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev"
flux_state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev"
)
metadata.update(sai_metadata)
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype, metadata)
save_to_file(args.save_to, flux_state_dict, save_dtype, metadata)
def setup_parser() -> argparse.ArgumentParser:
@@ -592,6 +676,18 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする",
)
parser.add_argument(
"--clip_l",
type=str,
default=None,
help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス*.sftまたは*.safetensors",
)
parser.add_argument(
"--t5xxl",
type=str,
default=None,
help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス*.sftまたは*.safetensors",
)
parser.add_argument(
"--mem_eff_load_save",
action="store_true",
@@ -617,6 +713,18 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="destination file name: safetensors file / 保存先のファイル名、safetensorsファイル",
)
parser.add_argument(
"--clip_l_save_to",
type=str,
default=None,
help="destination file name for clip_l: safetensors file / clip_lの保存先のファイル名、safetensorsファイル",
)
parser.add_argument(
"--t5xxl_save_to",
type=str,
default=None,
help="destination file name for t5xxl: safetensors file / t5xxlの保存先のファイル名、safetensorsファイル",
)
parser.add_argument(
"--models",
type=str,


@@ -330,6 +330,11 @@ def create_network(
if split_qkv is not None:
split_qkv = True if split_qkv == "True" else False
# train T5XXL
train_t5xxl = kwargs.get("train_t5xxl", False)
if train_t5xxl is not None:
train_t5xxl = True if train_t5xxl == "True" else False
# there are really a lot of arguments here ( ^ω^)・・・
network = LoRANetwork(
text_encoders,
@@ -344,6 +349,7 @@ def create_network(
conv_alpha=conv_alpha,
train_blocks=train_blocks,
split_qkv=split_qkv,
train_t5xxl=train_t5xxl,
varbose=True,
)
@@ -370,9 +376,10 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
else:
weights_sd = torch.load(file, map_location="cpu")
# get dim/alpha mapping
# get dim/alpha mapping, and train t5xxl
modules_dim = {}
modules_alpha = {}
train_t5xxl = None
for key, value in weights_sd.items():
if "." not in key:
continue
@@ -385,6 +392,12 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
modules_dim[lora_name] = dim
# logger.info(lora_name, value.size(), dim)
if train_t5xxl is None or train_t5xxl is False:
train_t5xxl = "lora_te3" in lora_name
if train_t5xxl is None:
train_t5xxl = False
# # split qkv
# double_qkv_rank = None
# single_qkv_rank = None
@@ -413,6 +426,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
modules_alpha=modules_alpha,
module_class=module_class,
split_qkv=split_qkv,
train_t5xxl=train_t5xxl,
)
return network, weights_sd
@@ -421,10 +435,10 @@ class LoRANetwork(torch.nn.Module):
# FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"]
FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"]
LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible
LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2"
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible
def __init__(
self,
@@ -443,6 +457,7 @@ class LoRANetwork(torch.nn.Module):
modules_alpha: Optional[Dict[str, int]] = None,
train_blocks: Optional[str] = None,
split_qkv: bool = False,
train_t5xxl: bool = False,
varbose: Optional[bool] = False,
) -> None:
super().__init__()
@@ -457,6 +472,7 @@ class LoRANetwork(torch.nn.Module):
self.module_dropout = module_dropout
self.train_blocks = train_blocks if train_blocks is not None else "all"
self.split_qkv = split_qkv
self.train_t5xxl = train_t5xxl
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
@@ -469,12 +485,16 @@ class LoRANetwork(torch.nn.Module):
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
if self.conv_lora_dim is not None:
logger.info(
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
)
# if self.conv_lora_dim is not None:
# logger.info(
# f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
# )
if self.split_qkv:
logger.info(f"split qkv for LoRA")
if self.train_blocks is not None:
logger.info(f"train {self.train_blocks} blocks only")
if train_t5xxl:
logger.info(f"train T5XXL as well")
# create module instances
def create_modules(
@@ -550,12 +570,15 @@ class LoRANetwork(torch.nn.Module):
skipped_te = []
for i, text_encoder in enumerate(text_encoders):
index = i
if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
break
logger.info(f"create LoRA for Text Encoder {index+1}:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# create LoRA for U-Net
if self.train_blocks == "all":


@@ -368,12 +368,32 @@ def train(args):
vae.eval()
vae.to(accelerator.device, dtype=vae_dtype)
mmdit.requires_grad_(train_mmdit)
if not train_mmdit:
mmdit.to(accelerator.device, dtype=weight_dtype) # because unet is not prepared by accelerator
if args.num_last_block_to_freeze:
# freeze last n blocks of MM-DIT
block_name = "x_block"
filtered_blocks = [(name, param) for name, param in mmdit.named_parameters() if block_name in name]
accelerator.print(f"filtered_blocks: {len(filtered_blocks)}")
num_blocks_to_freeze = min(len(filtered_blocks), args.num_last_block_to_freeze)
accelerator.print(f"freeze_blocks: {num_blocks_to_freeze}")
start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze)
for i in range(start_freezing_from, len(filtered_blocks)):
_, param = filtered_blocks[i]
param.requires_grad = False
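The freezing logic above can be isolated into a helper; a sketch under the same approach (the helper name and signature are hypothetical, and the real script filters on the `"x_block"` name substring):

```python
import torch.nn as nn

def freeze_last_n_matching(model: nn.Module, substring: str, n: int) -> int:
    """Freeze the last n parameters whose name contains `substring`; return count frozen."""
    matched = [(name, p) for name, p in model.named_parameters() if substring in name]
    start = max(0, len(matched) - min(n, len(matched)))
    for _, p in matched[start:]:
        p.requires_grad = False
    return len(matched) - start
```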
training_models = []
params_to_optimize = []
# if train_unet:
training_models.append(mmdit)
# if block_lrs is None:
params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate})
params_to_optimize.append({"params": list(filter(lambda p: p.requires_grad, mmdit.parameters())), "lr": args.learning_rate})
# else:
# params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs))
@@ -1026,12 +1046,17 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
)
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",
help="skip latents validity check / latentsの正当性チェックをスキップする",
)
parser.add_argument(
"--num_last_block_to_freeze",
type=int,
default=None,
help="freeze last n blocks of MM-DIT / MM-DITの最後のnブロックを凍結する",
)
return parser


@@ -157,6 +157,9 @@ class NetworkTrainer:
# region SD/SDXL
def post_process_network(self, args, accelerator, network, text_encoders, unet):
pass
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
@@ -237,6 +240,13 @@ class NetworkTrainer:
def is_text_encoder_not_needed_for_training(self, args):
return False # use for sample images
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
# set top parameter requires_grad = True so that gradient checkpointing works
text_encoder.text_model.embeddings.requires_grad_(True)
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
# endregion
def train(self, args):
@@ -329,7 +339,7 @@ class NetworkTrainer:
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
self.assert_extra_args(args, train_dataset_group)
self.assert_extra_args(args, train_dataset_group) # may change some args
# prepare the accelerator
logger.info("preparing accelerator")
@@ -428,12 +438,15 @@ class NetworkTrainer:
)
args.scale_weight_norms = False
self.post_process_network(args, accelerator, network, text_encoders, unet)
# apply network to unet and text_encoder
train_unet = not args.network_train_text_encoder_only
train_text_encoder = self.is_train_text_encoder(args)
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
if args.network_weights is not None:
# FIXME consider alpha of weights
# FIXME consider alpha of weights: this assumes that the alpha is not changed
info = network.load_weights(args.network_weights)
accelerator.print(f"load network weights from {args.network_weights}: {info}")
@@ -533,7 +546,7 @@ class NetworkTrainer:
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
accelerator.print("enable fp8 training for U-Net.")
unet_weight_dtype = torch.float8_e4m3fn
if not args.fp8_base_unet:
accelerator.print("enable fp8 training for Text Encoder.")
te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn
@@ -545,17 +558,16 @@ class NetworkTrainer:
unet.requires_grad_(False)
unet.to(dtype=unet_weight_dtype)
for t_enc in text_encoders:
for i, t_enc in enumerate(text_encoders):
t_enc.requires_grad_(False)
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
if t_enc.device.type != "cpu":
t_enc.to(dtype=te_weight_dtype)
if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"):
# nn.Embedding does not support FP8
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"):
t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
# nn.Embedding does not support FP8
if te_weight_dtype != weight_dtype:
self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if args.deepspeed:
@@ -596,12 +608,12 @@ class NetworkTrainer:
if args.gradient_checkpointing:
# according to TI example in Diffusers, train is required
unet.train()
for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))):
t_enc.train()
# set top parameter requires_grad = True so that gradient checkpointing works
if frag:
t_enc.text_model.embeddings.requires_grad_(True)
self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc)
else:
unet.eval()
@@ -1028,8 +1040,12 @@ class NetworkTrainer:
# log device and dtype for each model
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
for t_enc in text_encoders:
logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}")
for i, t_enc in enumerate(text_encoders):
params_itr = t_enc.parameters()
next(params_itr) # skip the first parameter
next(params_itr) # skip the second parameter; CLIP's first two parameters are embeddings
param_3rd = next(params_itr)
logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}")
clean_memory_on_device(accelerator.device)
@@ -1081,15 +1097,11 @@ class NetworkTrainer:
# print(f"set multiplier: {multipliers}")
accelerator.unwrap_model(network).set_multiplier(multipliers)
text_encoder_conds = []
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
if (
text_encoder_conds is None
or len(text_encoder_conds) == 0
or text_encoder_conds[0] is None
or train_text_encoder
):
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
@@ -1112,10 +1124,14 @@ class NetworkTrainer:
if args.full_fp16:
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
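The cache-merge logic above, extracted as a pure function for illustration (the helper name is hypothetical): if nothing is cached, the freshly encoded outputs are used wholesale; otherwise each fresh output overwrites its cached counterpart where present.

```python
from typing import Any, List, Optional

def merge_cached_conds(cached: List[Any], encoded: List[Optional[Any]]) -> List[Any]:
    # no cache: use the freshly encoded outputs as-is
    if len(cached) == 0:
        return list(encoded)
    # otherwise overwrite cached entries wherever a fresh output exists
    return [e if e is not None else c for c, e in zip(cached, encoded)]
```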
# sample noise, call unet, get target
noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(