Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-15 16:39:42 +00:00
Merge branch 'sd3' of https://github.com/kohya-ss/sd-scripts into flux_shift
README.md
@@ -11,6 +11,18 @@ The command to install PyTorch is as follows:

### Recent Updates

Sep 5, 2024:

The LoRA merge script now supports CLIP-L and T5XXL LoRA. Please specify `--clip_l` and `--t5xxl`. `--clip_l_save_to` and `--t5xxl_save_to` specify the save destinations for CLIP-L and T5XXL. See [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) for details.

Sep 4, 2024:

- T5XXL LoRA is now supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is trained at the same time (T5XXL cannot be trained on its own). The trained model can be used with ComfyUI. See [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) for details.
- In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, fp16 weights are recommended for caching.
- Fixed an issue where the trained CLIP-L LoRA was not used in sample image generation during LoRA training.

Sep 1, 2024:

- `--timestep_sampling` now has a `flux_shift` option. Thanks to sdbds!
  - This is the same shift as FLUX.1 dev inference, adjusting the timestep sampling depending on the resolution. `--discrete_flow_shift` is ignored when `flux_shift` is specified. It is not yet verified which is better, `shift` or `flux_shift`.

Aug 29, 2024:

Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `requirements.txt` has been updated.

@@ -37,8 +49,8 @@ Sample command is below. It will work with 24GB VRAM GPUs.

```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py
--pretrained_model_name_or_path flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors
--ae ae.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers
--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
--network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4
--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base
```

@@ -68,11 +80,17 @@ The trained LoRA model can be used with ComfyUI.

There are many unknown points in FLUX.1 training, so some settings can be specified by arguments. Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome.

- `--pretrained_model_name_or_path` is the path to the pretrained model (FLUX.1). bf16 (the original BFL model) is recommended (`flux1-dev.safetensors` or `flux1-dev.sft`). If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with the `float8_e4m3fn` format.
- `--clip_l` is the path to the CLIP-L model.
- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (`float8_e4m3fn`) models for T5XXL. However, fp16 models are recommended for caching.
- `--ae` is the path to the autoencoder model (`ae.safetensors` or `ae.sft`).

- `--timestep_sampling` is the method to sample timesteps (0-1):
  - `sigma`: sigma-based, same as SD3
  - `uniform`: uniform random
  - `sigmoid`: sigmoid of a random normal, same as x-flux, AI-toolkit, etc.
  - `shift`: shifts the value of the sigmoid of a normally distributed random number
  - `flux_shift`: shifts the value of the sigmoid of a normally distributed random number depending on the resolution (same as FLUX.1 dev inference). `--discrete_flow_shift` is ignored when `flux_shift` is specified.
- `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep sampling is `sigmoid`). The default is 1.0. Larger values make the sampling more uniform.
  - This option is also effective when `--timestep_sampling shift` is specified.
  - Normally, leave it at 1.0. Larger values make the value before the shift closer to a uniform distribution.
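
For intuition, the `shift` and `flux_shift` behavior can be sketched in plain Python. The shift formula matches the description above; the `mu` interpolation constants (0.5 at 256 packed tokens, 1.15 at 4096) are taken from the published FLUX.1 dev inference code and are an assumption here, not a restatement of this repository's implementation.

```
import math
import random

def shift_sigmoid_timestep(shift=3.1582, sigmoid_scale=1.0):
    # "shift": sigmoid of a normal sample (scaled by sigmoid_scale), then shifted;
    # shift > 1 pushes the sampled timestep toward 1 (the noisier end)
    t = 1.0 / (1.0 + math.exp(-random.gauss(0.0, 1.0) * sigmoid_scale))
    return (t * shift) / (1.0 + (shift - 1.0) * t)

def flux_shift_value(image_height, image_width):
    # "flux_shift": the shift depends on the packed latent sequence length,
    # with mu interpolated linearly between 0.5 (256 tokens) and 1.15 (4096 tokens)
    seq_len = (image_height // 16) * (image_width // 16)
    m = (1.15 - 0.5) / (4096 - 256)
    mu = m * (seq_len - 256) + 0.5
    return math.exp(mu)
```

For a 1024x1024 image the packed sequence length is 4096, so the shift is exp(1.15) ≈ 3.16; as noted above, `--discrete_flow_shift` is ignored in this mode.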
@@ -109,16 +127,29 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times

#### Key Features for FLUX.1 LoRA training

1. CLIP-L and T5XXL LoRA Support:
- FLUX.1 LoRA training now supports CLIP-L and T5XXL LoRA training.
- Remove `--network_train_unet_only` from your command.
- Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time.
- T5XXL output can be cached for CLIP-L LoRA training, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- The trained LoRA can be used with ComfyUI.
- Note: `flux_extract_lora.py`, `convert_flux_lora.py`, and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet.

| trained LoRA | option | network_args | cache_text_encoder_outputs (*1) |
|---|---|---|---|
| FLUX.1 | `--network_train_unet_only` | - | o |
| FLUX.1 + CLIP-L | - | - | o (*2) |
| FLUX.1 + CLIP-L + T5XXL | - | `train_t5xxl=True` | - |
| CLIP-L (*3) | `--network_train_text_encoder_only` | - | o (*2) |
| CLIP-L + T5XXL (*3) | `--network_train_text_encoder_only` | `train_t5xxl=True` | - |

- *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- *2: T5XXL output can be cached for CLIP-L LoRA training.
- *3: Not tested yet.
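
The first three rows of the table (the combinations that train FLUX.1) can be encoded as a small helper for sanity checking. `lora_train_args` is a hypothetical illustration, not a function in sd-scripts:

```
def lora_train_args(train_clip_l=False, train_t5xxl=False):
    # map the desired trained modules to the flags in the table above
    args = []
    if not train_clip_l and not train_t5xxl:
        args.append("--network_train_unet_only")
    if train_t5xxl:
        if not train_clip_l:
            # per the notes above, T5XXL-only training is not supported
            raise ValueError("T5XXL cannot be trained without CLIP-L")
        args += ["--network_args", "train_t5xxl=True"]
    else:
        # when T5XXL is not trained, its outputs can still be cached
        args.append("--cache_text_encoder_outputs")
    return args
```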

2. Experimental FP8/FP16 mixed training:
- `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L/T5XXL.
- FLUX can be trained with fp8, and CLIP-L/T5XXL can be trained with bf16/fp16.
- When this option is specified, the `--fp8_base` option is automatically enabled.

3. Split Q/K/V Projection Layers (Experimental):
@@ -148,7 +179,7 @@ The compatibility of the saved model (state dict) is ensured by concatenating th

The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.

```
python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.safetensors --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora "lora-flux-name.safetensors;1.0"
```

### FLUX.1 fine-tuning
@@ -159,7 +190,7 @@ Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GP

```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py
--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.safetensors
--save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2
--seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
--dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name
```

@@ -179,7 +210,7 @@ Options are almost the same as LoRA training. The difference is `--full_bf16`, `

`--blockwise_fused_optimizers` enables fusing the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now.

`--double_blocks_to_swap` and `--single_blocks_to_swap` specify the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified together with `--single_blocks_to_swap`. The recommended maximum is 9 double blocks and 18 single blocks. Please see the next chapter for details.

`--cpu_offload_checkpointing` offloads gradient checkpointing activations to the CPU. This reduces VRAM usage by about 2GB.

@@ -193,24 +224,32 @@ The learning rate and the number of epochs are not optimized yet. Please adjust

#### Key Features for FLUX.1 fine-tuning

1. Technical details of double/single block swap:
- Reduces memory usage by transferring double and single blocks of FLUX.1 from GPU to CPU when they are not needed.
- During the forward pass, the weights of blocks that have finished their calculation are transferred to the CPU, and the weights of blocks about to be calculated are transferred to the GPU.
- The same applies to the backward pass, but in reverse order. The gradients remain on the GPU.
- Since transfers between CPU and GPU take time, training will be slower.
- `--double_blocks_to_swap` and `--single_blocks_to_swap` specify the number of blocks to swap. For example, `--double_blocks_to_swap 6` swaps 6 blocks at each training step, while the remaining 13 blocks stay on the GPU.
- About 640MB of memory can be saved per double block, and about 320MB per single block.
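
A back-of-the-envelope estimate of the total savings, using the per-block figures above (actual numbers depend on dtype and implementation details):

```
def estimated_swap_savings_mb(double_blocks_to_swap=0, single_blocks_to_swap=0):
    # ~640MB per swapped double block, ~320MB per swapped single block (see above)
    return double_blocks_to_swap * 640 + single_blocks_to_swap * 320
```

At the recommended maximums (`--double_blocks_to_swap 9 --single_blocks_to_swap 18`) this estimates 9 * 640 + 18 * 320 = 11520MB, roughly 11.3GB saved, at the cost of slower steps.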

2. Sample Image Generation:
- Sample image generation during training is now supported.
- The prompts are cached and used for generation if `--cache_latents` is specified, so changing the prompts during training will not affect the generated images.
- Specify options such as `--sample_prompts` and `--sample_every_n_epochs`.
- Note: it will be very slow when `--split_mode` is specified.

3. Experimental Memory-Efficient Saving:
- The `--mem_eff_save` option can further reduce memory consumption during model saving (about 22GB).
- This is a custom implementation and may cause unexpected issues. Use with caution.

4. T5XXL Token Length Control:
- Added the `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL.
- The default is 512 for dev and 256 for schnell models.

5. Multi-GPU Training Support:
- Note: `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training.

6. Disable mmap Load for Safetensors:
- The `--disable_mmap_load_safetensors` option now works in `flux_train.py`.
- Speeds up model loading during training in WSL2.
- Effective in reducing memory usage when loading models during multi-GPU training.

@@ -240,21 +279,32 @@ CLIP-L LoRA is not supported.

### Merge LoRA to FLUX.1 checkpoint

`networks/flux_merge_lora.py` merges LoRA into a FLUX.1 checkpoint, CLIP-L, or T5XXL model. __The script is experimental.__

```
python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu
```

You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models` and the same number of ratios in `--ratios`.

CLIP-L and T5XXL LoRA are supported. `--clip_l` and `--clip_l_save_to` are for CLIP-L; `--t5xxl` and `--t5xxl_save_to` are for T5XXL. A sample command is below.

```
--clip_l clip_l.safetensors --clip_l_save_to merged_clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --t5xxl_save_to merged_t5xxl.safetensors
```

FLUX.1, CLIP-L, and T5XXL can be merged together or separately for memory efficiency.

An experimental option `--mem_eff_load_save` is available for memory-efficient loading and saving. It may also speed up loading and saving.

`--loading_device` is the device used to load the LoRA models. `--working_device` is the device used to merge (calculate) the models. The default is `cpu` for both. Loading/working device examples are below (in the case of `--save_precision fp16` or `--save_precision bf16`; `float32` will consume more memory):

- 'cpu' / 'cpu': Uses >50GB of RAM, but works on any machine.
- 'cuda' / 'cpu': Uses 24GB of VRAM, but requires 30GB of RAM.
- 'cpu' / 'cuda': Uses 4GB of VRAM, but requires 50GB of RAM; faster than 'cpu' / 'cpu' or 'cuda' / 'cpu'.
- 'cuda' / 'cuda': Uses 30GB of VRAM, but requires 30GB of RAM; faster than 'cpu' / 'cpu' or 'cuda' / 'cpu'.

`--save_precision` is the precision used to save the merged model. When the LoRA models are trained with `bf16`, we are not sure which is better for `--save_precision`, `fp16` or `bf16`.

The script can merge multiple LoRA models. If you want to merge multiple LoRA models, specify the `--concat` option so the merged LoRA model works properly.
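
The device trade-offs above can be condensed into a small decision helper. This is a hypothetical sketch (not part of sd-scripts); the GB figures are the rough ones quoted above, and the preference order assumes 'cuda' / 'cuda' is fastest overall:

```
# rough requirements quoted above, in GB: (loading, working) -> (VRAM, RAM)
MERGE_COSTS = {
    ("cuda", "cuda"): (30, 30),
    ("cpu", "cuda"): (4, 50),
    ("cuda", "cpu"): (24, 30),
    ("cpu", "cpu"): (0, 50),  # listed as >50GB of RAM above
}

def pick_merge_devices(vram_gb, ram_gb):
    # return the first (fastest) pair that fits the available memory, or None
    for pair, (vram, ram) in MERGE_COSTS.items():
        if vram_gb >= vram and ram_gb >= ram:
            return pair
    return None
```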
@@ -309,6 +359,9 @@ resolution = [512, 512]

SD3 training is done with `sd3_train.py`.

__Sep 1, 2024__:
- `--num_last_block_to_freeze` is added to `sd3_train.py`. This option freezes the last n blocks of the MMDiT. See [#1417](https://github.com/kohya-ss/sd-scripts/pull/1417) for details. Thanks to sdbds!

__Jul 27, 2024__:
- The latents and text encoder outputs caching mechanism has been refactored significantly.
- Existing cache files for SD3 need to be recreated. Please delete the previous cache files.

@@ -5,7 +5,7 @@ import datetime
import math
import os
import random
from typing import Callable, List, Optional
import einops
import numpy as np

@@ -13,6 +13,7 @@ import torch
from tqdm import tqdm
from PIL import Image
import accelerate
from transformers import CLIPTextModel

from library import device_utils
from library.device_utils import init_ipex, get_preferred_device
@@ -125,7 +126,7 @@ def do_sample(

def generate_image(
    model,
    clip_l: CLIPTextModel,
    t5xxl,
    ae,
    prompt: str,
@@ -141,12 +142,13 @@ def generate_image(
    # make first noise with packed shape
    # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
    packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
    noise_dtype = torch.float32 if is_fp8(dtype) else dtype
    noise = torch.randn(
        1,
        packed_latent_height * packed_latent_width,
        16 * 2 * 2,
        device=device,
        dtype=noise_dtype,
        generator=torch.Generator(device=device).manual_seed(seed),
    )

@@ -166,9 +168,48 @@ def generate_image(
    clip_l = clip_l.to(device)
    t5xxl = t5xxl.to(device)
    with torch.no_grad():
        if is_fp8(clip_l_dtype):
            param_itr = clip_l.parameters()
            param_itr.__next__()  # skip first
            param_2nd = param_itr.__next__()
            if param_2nd.dtype != clip_l_dtype:
                logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
                clip_l.to(clip_l_dtype)  # fp8
                clip_l.text_model.embeddings.to(dtype=torch.bfloat16)

        with accelerator.autocast():
            l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)

        if is_fp8(t5xxl_dtype):
            if flux_utils.get_t5xxl_actual_dtype(t5xxl) != t5xxl_dtype or not hasattr(t5xxl, "fp8_prepared"):
                logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")

                def prepare_fp8(text_encoder, target_dtype):
                    def forward_hook(module):
                        def forward(hidden_states):
                            hidden_gelu = module.act(module.wi_0(hidden_states))
                            hidden_linear = module.wi_1(hidden_states)
                            hidden_states = hidden_gelu * hidden_linear
                            hidden_states = module.dropout(hidden_states)

                            hidden_states = module.wo(hidden_states)
                            return hidden_states

                        return forward

                    for module in text_encoder.modules():
                        if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
                            # cast norms and embeddings to the working dtype, keeping other weights fp8
                            module.to(target_dtype)
                        if module.__class__.__name__ in ["T5DenseGatedActDense"]:
                            # replace forward so activations are not cast down to the fp8 weight dtype
                            module.forward = forward_hook(module)

                    text_encoder.fp8_prepared = True

                t5xxl.to(t5xxl_dtype)
                prepare_fp8(t5xxl.encoder, torch.bfloat16)

        with accelerator.autocast():
            _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
                tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
@@ -315,10 +356,10 @@ if __name__ == "__main__":
    t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
    t5xxl.eval()

    # if is_fp8(clip_l_dtype):
    #     clip_l = accelerator.prepare(clip_l)
    # if is_fp8(t5xxl_dtype):
    #     t5xxl = accelerator.prepare(t5xxl)

    t5xxl_max_length = 256 if is_schnell else 512
    tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
@@ -329,14 +370,16 @@ if __name__ == "__main__":
    model.eval()
    logger.info(f"Casting model to {flux_dtype}")
    model.to(flux_dtype)  # make sure model is dtype
    # if is_fp8(flux_dtype):
    #     model = accelerator.prepare(model)
    # if args.offload:
    #     model = model.to("cpu")

    # AE
    ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device)
    ae.eval()
    # if is_fp8(ae_dtype):
    #     ae = accelerator.prepare(ae)

    # LoRA
    lora_models: List[lora_flux.LoRANetwork] = []
@@ -360,7 +403,7 @@ if __name__ == "__main__":
        lora_model.to(device)

        lora_models.append(lora_model)

    if not args.interactive:
        generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance)
    else:
@@ -651,7 +651,7 @@ def train(args):
        else:
            with torch.no_grad():
                # encode images to latents. images are [-1, 1]
                latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype)

                # if NaN is present, show a warning and replace NaN with zeros
                if torch.any(torch.isnan(latents)):
@@ -43,13 +43,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
            train_dataset_group.is_text_encoder_output_cacheable()
        ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"

        # prepare CLIP-L/T5XXL training flags
        self.train_clip_l = not args.network_train_unet_only
        self.train_t5xxl = False  # default is False even if args.network_train_unet_only is False

        if args.max_token_length is not None:
            logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
@@ -63,12 +59,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
        # currently offload to cpu for some models
        name = self.get_flux_model_name(args)

        # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
        loading_dtype = None if args.fp8_base else weight_dtype

        # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
        model = flux_utils.load_flow_model(
            name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
        )
@@ -85,9 +79,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
        clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
        clip_l.eval()

        # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
        if args.fp8_base and not args.fp8_base_unet:
            loading_dtype = None  # as is
        else:
            loading_dtype = weight_dtype

        # loading t5xxl to cpu takes a long time, so we should load to gpu in future
        t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
        t5xxl.eval()
        if args.fp8_base and not args.fp8_base_unet:
            # check dtype of model
            if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
                raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
            elif t5xxl.dtype == torch.float8_e4m3fn:
                logger.info("Loaded fp8 T5XXL model")

        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
@@ -154,25 +160,35 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
    def get_text_encoding_strategy(self, args):
        return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)

    def post_process_network(self, args, accelerator, network, text_encoders, unet):
        # check if t5xxl is trained or not
        self.train_t5xxl = network.train_t5xxl

        if self.train_t5xxl and args.cache_text_encoder_outputs:
            raise ValueError(
                "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
            )

    def get_models_for_text_encoding(self, args, accelerator, text_encoders):
        if args.cache_text_encoder_outputs:
            if self.train_clip_l and not self.train_t5xxl:
                return text_encoders[0:1]  # only CLIP-L is needed for encoding because T5XXL is cached
            else:
                return None  # no text encoders are needed for encoding because both are cached
        else:
            return text_encoders  # both CLIP-L and T5XXL are needed for encoding

    def get_text_encoders_train_flags(self, args, text_encoders):
        return [self.train_clip_l, self.train_t5xxl]

    def get_text_encoder_outputs_caching_strategy(self, args):
        if args.cache_text_encoder_outputs:
            # if the text encoders are trained, we need tokenization, so is_partial is True
            return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
                args.cache_text_encoder_outputs_to_disk,
                None,
                False,
                is_partial=self.train_clip_l or self.train_t5xxl,
                apply_t5_attn_mask=args.apply_t5_attn_mask,
            )
        else:
@@ -193,8 +209,16 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):

            # When the text encoder is not trained, it is not prepared by accelerate, so we need to use explicit autocast
            logger.info("move text encoders to gpu")
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)  # always not fp8
            text_encoders[1].to(accelerator.device)

            if text_encoders[1].dtype == torch.float8_e4m3fn:
                # if we load fp8 weights, the model is already fp8, so we use it as is
                self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
            else:
                # otherwise, we need to convert it to target dtype
                text_encoders[1].to(weight_dtype)

            with accelerator.autocast():
                dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process)

@@ -235,7 +259,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
        else:
            # outputs are fetched from the text encoders every step, so keep them on the GPU
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
            text_encoders[1].to(accelerator.device)

        # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
        #     noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
@@ -255,9 +279,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
        # return noise_pred

    def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
        text_encoders = text_encoder  # for compatibility
        text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)

        if not args.split_mode:
            flux_train_utils.sample_images(
                accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
            )
            return
@@ -281,7 +308,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
        wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
        clean_memory_on_device(accelerator.device)
        flux_train_utils.sample_images(
            accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
        )
        clean_memory_on_device(accelerator.device)
@@ -421,6 +448,47 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
    def is_text_encoder_not_needed_for_training(self, args):
        return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)

+   def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
+       if index == 0:  # CLIP-L
+           return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
+       else:  # T5XXL
+           text_encoder.encoder.embed_tokens.requires_grad_(True)
+
+   def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
+       if index == 0:  # CLIP-L
+           logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
+           text_encoder.to(te_weight_dtype)  # fp8
+           text_encoder.text_model.embeddings.to(dtype=weight_dtype)
+       else:  # T5XXL
+
+           def prepare_fp8(text_encoder, target_dtype):
+               def forward_hook(module):
+                   def forward(hidden_states):
+                       hidden_gelu = module.act(module.wi_0(hidden_states))
+                       hidden_linear = module.wi_1(hidden_states)
+                       hidden_states = hidden_gelu * hidden_linear
+                       hidden_states = module.dropout(hidden_states)
+
+                       hidden_states = module.wo(hidden_states)
+                       return hidden_states
+
+                   return forward
+
+               for module in text_encoder.modules():
+                   if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
+                       # print("set", module.__class__.__name__, "to", target_dtype)
+                       module.to(target_dtype)
+                   if module.__class__.__name__ in ["T5DenseGatedActDense"]:
+                       # print("set", module.__class__.__name__, "hooks")
+                       module.forward = forward_hook(module)
+
+           if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
+               logger.info(f"T5XXL already prepared for fp8")
+           else:
+               logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
+               text_encoder.to(te_weight_dtype)  # fp8
+               prepare_fp8(text_encoder, weight_dtype)
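The T5XXL fp8 preparation above keeps `T5LayerNorm`/`Embedding` modules in higher precision and monkey-patches `T5DenseGatedActDense.forward` through a closure that captures the module. A minimal, dependency-free sketch of that closure-patching pattern (the `GatedDense` toy class is a hypothetical stand-in, not part of the diff):

```python
# Toy stand-in for T5DenseGatedActDense: gated activation, dropout, output proj.
class GatedDense:
    def wi_0(self, x):
        return x * 2

    def wi_1(self, x):
        return x + 3

    def act(self, x):
        return max(x, 0)

    def dropout(self, x):
        return x

    def wo(self, x):
        return x - 1

    def forward(self, x):
        return self.wo(self.dropout(self.act(self.wi_0(x)) * self.wi_1(x)))


def forward_hook(module):
    # The closure captures `module`; the replacement reproduces the original
    # computation, so extra casts could be inserted at any step later.
    def forward(hidden_states):
        hidden_gelu = module.act(module.wi_0(hidden_states))
        hidden_linear = module.wi_1(hidden_states)
        return module.wo(module.dropout(hidden_gelu * hidden_linear))

    return forward


m = GatedDense()
original = m.forward(2)       # act(4) * 5 = 20, dropout keeps it, wo -> 19
m.forward = forward_hook(m)   # patch in place, same math
assert m.forward(2) == original == 19
```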


def setup_parser() -> argparse.ArgumentParser:
    parser = train_network.setup_parser()
@@ -85,7 +85,7 @@ def sample_images(

    if distributed_state.num_processes <= 1:
        # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
-       with torch.no_grad():
+       with torch.no_grad(), accelerator.autocast():
            for prompt_dict in prompts:
                sample_image_inference(
                    accelerator,

@@ -187,14 +187,27 @@ def sample_image_inference(
    tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
    encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()

+   text_encoder_conds = []
    if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
-       te_outputs = sample_prompts_te_outputs[prompt]
-   else:
+       text_encoder_conds = sample_prompts_te_outputs[prompt]
+       print(f"Using cached text encoder outputs for prompt: {prompt}")
+   if text_encoders is not None:
+       print(f"Encoding prompt: {prompt}")
        tokens_and_masks = tokenize_strategy.tokenize(prompt)
        # strategy has apply_t5_attn_mask option
-       te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
+       encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
+       print([x.shape if x is not None else None for x in encoded_text_encoder_conds])

-   l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs
+       # if text_encoder_conds is not cached, use encoded_text_encoder_conds
+       if len(text_encoder_conds) == 0:
+           text_encoder_conds = encoded_text_encoder_conds
+       else:
+           # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
+           for i in range(len(encoded_text_encoder_conds)):
+               if encoded_text_encoder_conds[i] is not None:
+                   text_encoder_conds[i] = encoded_text_encoder_conds[i]
+
+   l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds

    # sample image
    weight_dtype = ae.dtype  # TODO give dtype as argument

@@ -586,8 +599,8 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
        "--timestep_sampling",
+       choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
        default="sigma",
-       help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid."
-       " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。",
+       help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
+       " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
    )
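The new `flux_shift` option applies the same resolution-dependent shift that FLUX.1 dev uses at inference: a linear function of the packed-latent sequence length picks `mu`, which then reshapes the timestep schedule. A hedged sketch of that math (the constants 0.5/1.15 over sequence lengths 256-4096 are the commonly cited FLUX.1 dev defaults, assumed here rather than taken from this diff):

```python
import math


def flux_shift_mu(seq_len: int, x1: int = 256, x2: int = 4096,
                  y1: float = 0.5, y2: float = 1.15) -> float:
    # Linear interpolation of mu over the packed-latent sequence length.
    m = (y2 - y1) / (x2 - x1)
    return m * seq_len + y1 - m * x1


def time_shift(mu: float, t: float, sigma: float = 1.0) -> float:
    # Shifted schedule; at t = 0.5 and sigma = 1 this reduces to sigmoid(mu).
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


mu = flux_shift_mu(1024)        # e.g. a 1024-token packed latent
shifted = time_shift(mu, 0.5)   # higher resolution -> larger mu -> later timesteps
assert 0.5 < shifted < 1.0      # mu > 0, so the midpoint is pushed above 0.5
```

Because `mu` grows with resolution, larger images spend proportionally more of the schedule at high noise levels, which is the behavior `--discrete_flow_shift` approximated with a fixed value.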
    parser.add_argument(
        "--sigmoid_scale",

@@ -171,7 +171,9 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev
    return clip


-def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> T5EncoderModel:
+def load_t5xxl(
+    ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
+) -> T5EncoderModel:
    T5_CONFIG_JSON = """
{
    "architectures": [

@@ -217,6 +219,11 @@ def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.devi
    return t5xxl


+def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype:
+    # nn.Embedding is the first layer, but it could be casted to bfloat16 or float32
+    return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype


def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
    img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]
@@ -5,8 +5,7 @@ import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast

-from library import sd3_utils, train_util
-from library import sd3_models
+from library import flux_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy

from library.utils import setup_logging

@@ -100,6 +99,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
        self.apply_t5_attn_mask = apply_t5_attn_mask

+       self.warn_fp8_weights = False
+
    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX

@@ -144,6 +145,14 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
    ):
+       if not self.warn_fp8_weights:
+           if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
+               logger.warning(
+                   "T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs."
+                   " / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
+               )
+           self.warn_fp8_weights = True
+
        flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
        captions = [info.caption for info in infos]
@@ -2,6 +2,7 @@ import argparse
import math
import os
import time
+from typing import Any, Dict, Union

import torch
from safetensors import safe_open

@@ -34,11 +35,11 @@ def load_state_dict(file_name, dtype):
    return sd, metadata


-def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False):
+def save_to_file(file_name, state_dict: Dict[str, Union[Any, torch.Tensor]], dtype, metadata, mem_eff_save=False):
    if dtype is not None:
        logger.info(f"converting to {dtype}...")
        for key in tqdm(list(state_dict.keys())):
-           if type(state_dict[key]) == torch.Tensor:
+           if type(state_dict[key]) == torch.Tensor and state_dict[key].dtype.is_floating_point:
                state_dict[key] = state_dict[key].to(dtype)

    logger.info(f"saving to: {file_name}")

@@ -49,26 +50,76 @@ def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False):


def merge_to_flux_model(
-   loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False
+   loading_device,
+   working_device,
+   flux_path: str,
+   clip_l_path: str,
+   t5xxl_path: str,
+   models,
+   ratios,
+   merge_dtype,
+   save_dtype,
+   mem_eff_load_save=False,
):
    # create module map without loading state_dict
-   logger.info(f"loading keys from FLUX.1 model: {flux_model}")
    lora_name_to_module_key = {}
-   with safe_open(flux_model, framework="pt", device=loading_device) as flux_file:
-       keys = list(flux_file.keys())
-       for key in keys:
-           if key.endswith(".weight"):
-               module_name = ".".join(key.split(".")[:-1])
-               lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_")
-               lora_name_to_module_key[lora_name] = key
+   if flux_path is not None:
+       logger.info(f"loading keys from FLUX.1 model: {flux_path}")
+       with safe_open(flux_path, framework="pt", device=loading_device) as flux_file:
+           keys = list(flux_file.keys())
+           for key in keys:
+               if key.endswith(".weight"):
+                   module_name = ".".join(key.split(".")[:-1])
+                   lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_")
+                   lora_name_to_module_key[lora_name] = key

+   lora_name_to_clip_l_key = {}
+   if clip_l_path is not None:
+       logger.info(f"loading keys from clip_l model: {clip_l_path}")
+       with safe_open(clip_l_path, framework="pt", device=loading_device) as clip_l_file:
+           keys = list(clip_l_file.keys())
+           for key in keys:
+               if key.endswith(".weight"):
+                   module_name = ".".join(key.split(".")[:-1])
+                   lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP + "_" + module_name.replace(".", "_")
+                   lora_name_to_clip_l_key[lora_name] = key
+
+   lora_name_to_t5xxl_key = {}
+   if t5xxl_path is not None:
+       logger.info(f"loading keys from t5xxl model: {t5xxl_path}")
+       with safe_open(t5xxl_path, framework="pt", device=loading_device) as t5xxl_file:
+           keys = list(t5xxl_file.keys())
+           for key in keys:
+               if key.endswith(".weight"):
+                   module_name = ".".join(key.split(".")[:-1])
+                   lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5 + "_" + module_name.replace(".", "_")
+                   lora_name_to_t5xxl_key[lora_name] = key
+
+   flux_state_dict = {}
+   clip_l_state_dict = {}
+   t5xxl_state_dict = {}
    if mem_eff_load_save:
-       flux_state_dict = {}
-       with MemoryEfficientSafeOpen(flux_model) as flux_file:
-           for key in tqdm(flux_file.keys()):
-               flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device)  # dtype is not changed
+       if flux_path is not None:
+           with MemoryEfficientSafeOpen(flux_path) as flux_file:
+               for key in tqdm(flux_file.keys()):
+                   flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device)  # dtype is not changed
+
+       if clip_l_path is not None:
+           with MemoryEfficientSafeOpen(clip_l_path) as clip_l_file:
+               for key in tqdm(clip_l_file.keys()):
+                   clip_l_state_dict[key] = clip_l_file.get_tensor(key).to(loading_device)
+
+       if t5xxl_path is not None:
+           with MemoryEfficientSafeOpen(t5xxl_path) as t5xxl_file:
+               for key in tqdm(t5xxl_file.keys()):
+                   t5xxl_state_dict[key] = t5xxl_file.get_tensor(key).to(loading_device)
    else:
-       flux_state_dict = load_file(flux_model, device=loading_device)
+       if flux_path is not None:
+           flux_state_dict = load_file(flux_path, device=loading_device)
+       if clip_l_path is not None:
+           clip_l_state_dict = load_file(clip_l_path, device=loading_device)
+       if t5xxl_path is not None:
+           t5xxl_state_dict = load_file(t5xxl_path, device=loading_device)

    for model, ratio in zip(models, ratios):
        logger.info(f"loading: {model}")

@@ -81,8 +132,20 @@ def merge_to_flux_model(
            up_key = key.replace("lora_down", "lora_up")
            alpha_key = key[: key.index("lora_down")] + "alpha"

-           if lora_name not in lora_name_to_module_key:
-               logger.warning(f"no module found for LoRA weight: {key}. LoRA for Text Encoder is not supported yet.")
+           if lora_name in lora_name_to_module_key:
+               module_weight_key = lora_name_to_module_key[lora_name]
+               state_dict = flux_state_dict
+           elif lora_name in lora_name_to_clip_l_key:
+               module_weight_key = lora_name_to_clip_l_key[lora_name]
+               state_dict = clip_l_state_dict
+           elif lora_name in lora_name_to_t5xxl_key:
+               module_weight_key = lora_name_to_t5xxl_key[lora_name]
+               state_dict = t5xxl_state_dict
+           else:
+               logger.warning(
+                   f"no module found for LoRA weight: {key}. Skipping..."
+                   f"LoRAの重みに対応するモジュールが見つかりませんでした。スキップします。"
+               )
                continue

            down_weight = lora_sd.pop(key)

@@ -93,11 +156,7 @@ def merge_to_flux_model(
            scale = alpha / dim

            # W <- W + U * D
-           module_weight_key = lora_name_to_module_key[lora_name]
-           if module_weight_key not in flux_state_dict:
-               weight = flux_file.get_tensor(module_weight_key)
-           else:
-               weight = flux_state_dict[module_weight_key]
+           weight = state_dict[module_weight_key]

            weight = weight.to(working_device, merge_dtype)
            up_weight = up_weight.to(working_device, merge_dtype)

@@ -121,7 +180,7 @@ def merge_to_flux_model(
                # logger.info(conved.size(), weight.size(), module.stride, module.padding)
                weight = weight + ratio * conved * scale

-           flux_state_dict[module_weight_key] = weight.to(loading_device, save_dtype)
+           state_dict[module_weight_key] = weight.to(loading_device, save_dtype)
            del up_weight
            del down_weight
            del weight

@@ -129,7 +188,7 @@ def merge_to_flux_model(
    if len(lora_sd) > 0:
        logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}")

-   return flux_state_dict
+   return flux_state_dict, clip_l_state_dict, t5xxl_state_dict
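The merge loop above applies the standard LoRA update, W <- W + ratio * (up @ down) * (alpha / rank), to whichever state dict (FLUX, CLIP-L, or T5XXL) the LoRA key resolves to. A minimal pure-Python sketch of that arithmetic on toy matrices (no torch; list-of-lists stand in for tensors):

```python
def matmul(a, b):
    # Naive (out x rank) @ (rank x in) matrix product on nested lists.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight, up, down, alpha, ratio=1.0):
    rank = len(down)          # lora_down has shape (rank, in_features)
    scale = alpha / rank
    delta = matmul(up, down)  # full-rank delta, shape (out_features, in_features)
    return [
        [w + ratio * d * scale for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


W = [[1.0, 0.0], [0.0, 1.0]]  # base weight (2x2 identity)
U = [[1.0], [0.0]]            # lora_up, rank 1
D = [[0.0, 2.0]]              # lora_down, rank 1
merged = merge_lora(W, U, D, alpha=1.0)
assert merged == [[1.0, 2.0], [0.0, 1.0]]
```

This is the same computation the script performs tensor-wise per module; the three key maps only decide which checkpoint's `state_dict` receives each merged weight.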


def merge_to_flux_model_diffusers(

@@ -508,17 +567,28 @@ def merge(args):
    if save_dtype is None:
        save_dtype = merge_dtype

-   dest_dir = os.path.dirname(args.save_to)
+   assert (
+       args.save_to or args.clip_l_save_to or args.t5xxl_save_to
+   ), "save_to or clip_l_save_to or t5xxl_save_to must be specified / save_toまたはclip_l_save_toまたはt5xxl_save_toを指定してください"
+   dest_dir = os.path.dirname(args.save_to or args.clip_l_save_to or args.t5xxl_save_to)
    if not os.path.exists(dest_dir):
        logger.info(f"creating directory: {dest_dir}")
        os.makedirs(dest_dir)

-   if args.flux_model is not None:
+   if args.flux_model is not None or args.clip_l is not None or args.t5xxl is not None:
        if not args.diffusers:
-           state_dict = merge_to_flux_model(
+           assert (args.clip_l is None and args.clip_l_save_to is None) or (
+               args.clip_l is not None and args.clip_l_save_to is not None
+           ), "clip_l_save_to must be specified if clip_l is specified / clip_lが指定されている場合はclip_l_save_toも指定してください"
+           assert (args.t5xxl is None and args.t5xxl_save_to is None) or (
+               args.t5xxl is not None and args.t5xxl_save_to is not None
+           ), "t5xxl_save_to must be specified if t5xxl is specified / t5xxlが指定されている場合はt5xxl_save_toも指定してください"
+           flux_state_dict, clip_l_state_dict, t5xxl_state_dict = merge_to_flux_model(
                args.loading_device,
                args.working_device,
                args.flux_model,
+               args.clip_l,
+               args.t5xxl,
                args.models,
                args.ratios,
                merge_dtype,

@@ -526,7 +596,10 @@ def merge(args):
                args.mem_eff_load_save,
            )
        else:
-           state_dict = merge_to_flux_model_diffusers(
+           assert (
+               args.clip_l is None and args.t5xxl is None
+           ), "clip_l and t5xxl are not supported with --diffusers / clip_l、t5xxlはDiffusersではサポートされていません"
+           flux_state_dict = merge_to_flux_model_diffusers(
                args.loading_device,
                args.working_device,
                args.flux_model,

@@ -536,8 +609,10 @@ def merge(args):
                save_dtype,
                args.mem_eff_load_save,
            )
+           clip_l_state_dict = None
+           t5xxl_state_dict = None

-       if args.no_metadata:
+       if args.no_metadata or (flux_state_dict is None or len(flux_state_dict) == 0):
            sai_metadata = None
        else:
            merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models)

@@ -546,15 +621,24 @@ def merge(args):
                None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev"
            )

-       logger.info(f"saving FLUX model to: {args.save_to}")
-       save_to_file(args.save_to, state_dict, save_dtype, sai_metadata, args.mem_eff_load_save)
+       if flux_state_dict is not None and len(flux_state_dict) > 0:
+           logger.info(f"saving FLUX model to: {args.save_to}")
+           save_to_file(args.save_to, flux_state_dict, save_dtype, sai_metadata, args.mem_eff_load_save)
+
+       if clip_l_state_dict is not None and len(clip_l_state_dict) > 0:
+           logger.info(f"saving clip_l model to: {args.clip_l_save_to}")
+           save_to_file(args.clip_l_save_to, clip_l_state_dict, save_dtype, None, args.mem_eff_load_save)
+
+       if t5xxl_state_dict is not None and len(t5xxl_state_dict) > 0:
+           logger.info(f"saving t5xxl model to: {args.t5xxl_save_to}")
+           save_to_file(args.t5xxl_save_to, t5xxl_state_dict, save_dtype, None, args.mem_eff_load_save)
+
    else:
-       state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
+       flux_state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)

        logger.info("calculating hashes and creating metadata...")

-       model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+       model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(flux_state_dict, metadata)
        metadata["sshs_model_hash"] = model_hash
        metadata["sshs_legacy_hash"] = legacy_hash

@@ -562,12 +646,12 @@ def merge(args):
        merged_from = sai_model_spec.build_merged_from(args.models)
        title = os.path.splitext(os.path.basename(args.save_to))[0]
        sai_metadata = sai_model_spec.build_metadata(
-           state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev"
+           flux_state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev"
        )
        metadata.update(sai_metadata)

        logger.info(f"saving model to: {args.save_to}")
-       save_to_file(args.save_to, state_dict, save_dtype, metadata)
+       save_to_file(args.save_to, flux_state_dict, save_dtype, metadata)


def setup_parser() -> argparse.ArgumentParser:

@@ -592,6 +676,18 @@ def setup_parser() -> argparse.ArgumentParser:
        default=None,
        help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする",
    )
+   parser.add_argument(
+       "--clip_l",
+       type=str,
+       default=None,
+       help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)",
+   )
+   parser.add_argument(
+       "--t5xxl",
+       type=str,
+       default=None,
+       help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)",
+   )
    parser.add_argument(
        "--mem_eff_load_save",
        action="store_true",

@@ -617,6 +713,18 @@ def setup_parser() -> argparse.ArgumentParser:
        default=None,
        help="destination file name: safetensors file / 保存先のファイル名、safetensorsファイル",
    )
+   parser.add_argument(
+       "--clip_l_save_to",
+       type=str,
+       default=None,
+       help="destination file name for clip_l: safetensors file / clip_lの保存先のファイル名、safetensorsファイル",
+   )
+   parser.add_argument(
+       "--t5xxl_save_to",
+       type=str,
+       default=None,
+       help="destination file name for t5xxl: safetensors file / t5xxlの保存先のファイル名、safetensorsファイル",
+   )
    parser.add_argument(
        "--models",
        type=str,
@@ -330,6 +330,11 @@ def create_network(
    if split_qkv is not None:
        split_qkv = True if split_qkv == "True" else False

+   # train T5XXL
+   train_t5xxl = kwargs.get("train_t5xxl", False)
+   if train_t5xxl is not None:
+       train_t5xxl = True if train_t5xxl == "True" else False
+
    # すごく引数が多いな ( ^ω^)・・・
    network = LoRANetwork(
        text_encoders,

@@ -344,6 +349,7 @@ def create_network(
        conv_alpha=conv_alpha,
        train_blocks=train_blocks,
        split_qkv=split_qkv,
+       train_t5xxl=train_t5xxl,
        varbose=True,
    )

@@ -370,9 +376,10 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
    else:
        weights_sd = torch.load(file, map_location="cpu")

-   # get dim/alpha mapping
+   # get dim/alpha mapping, and train t5xxl
    modules_dim = {}
    modules_alpha = {}
+   train_t5xxl = None
    for key, value in weights_sd.items():
        if "." not in key:
            continue

@@ -385,6 +392,12 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
            modules_dim[lora_name] = dim
            # logger.info(lora_name, value.size(), dim)

+           if train_t5xxl is None or train_t5xxl is False:
+               train_t5xxl = "lora_te3" in lora_name
+
+   if train_t5xxl is None:
+       train_t5xxl = False
+
    # # split qkv
    # double_qkv_rank = None
    # single_qkv_rank = None

@@ -413,6 +426,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
        modules_alpha=modules_alpha,
        module_class=module_class,
        split_qkv=split_qkv,
+       train_t5xxl=train_t5xxl,
    )
    return network, weights_sd

@@ -421,10 +435,10 @@ class LoRANetwork(torch.nn.Module):
    # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
    FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"]
    FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"]
-   TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
+   TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"]
    LORA_PREFIX_FLUX = "lora_unet"  # make ComfyUI compatible
    LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
-   LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2"
+   LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3"  # make ComfyUI compatible

    def __init__(
        self,

@@ -443,6 +457,7 @@ class LoRANetwork(torch.nn.Module):
        modules_alpha: Optional[Dict[str, int]] = None,
        train_blocks: Optional[str] = None,
        split_qkv: bool = False,
+       train_t5xxl: bool = False,
        varbose: Optional[bool] = False,
    ) -> None:
        super().__init__()

@@ -457,6 +472,7 @@ class LoRANetwork(torch.nn.Module):
        self.module_dropout = module_dropout
        self.train_blocks = train_blocks if train_blocks is not None else "all"
        self.split_qkv = split_qkv
+       self.train_t5xxl = train_t5xxl

        self.loraplus_lr_ratio = None
        self.loraplus_unet_lr_ratio = None

@@ -469,12 +485,16 @@ class LoRANetwork(torch.nn.Module):
        logger.info(
            f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
        )
-       if self.conv_lora_dim is not None:
-           logger.info(
-               f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
-           )
+       # if self.conv_lora_dim is not None:
+       #     logger.info(
+       #         f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
+       #     )
        if self.split_qkv:
            logger.info(f"split qkv for LoRA")
        if self.train_blocks is not None:
            logger.info(f"train {self.train_blocks} blocks only")
+       if train_t5xxl:
+           logger.info(f"train T5XXL as well")

        # create module instances
        def create_modules(

@@ -550,12 +570,15 @@ class LoRANetwork(torch.nn.Module):
        skipped_te = []
        for i, text_encoder in enumerate(text_encoders):
            index = i
+           if not train_t5xxl and index > 0:  # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
+               break

-           logger.info(f"create LoRA for Text Encoder {index+1}:")

            text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+           logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
            self.text_encoder_loras.extend(text_encoder_loras)
            skipped_te += skipped
+       logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        # create LoRA for U-Net
        if self.train_blocks == "all":
sd3_train.py
@@ -368,12 +368,32 @@ def train(args):
    vae.eval()
    vae.to(accelerator.device, dtype=vae_dtype)

    mmdit.requires_grad_(train_mmdit)
    if not train_mmdit:
        mmdit.to(accelerator.device, dtype=weight_dtype)  # because of unet is not prepared

+   if args.num_last_block_to_freeze:
+       # freeze last n blocks of MM-DIT
+       block_name = "x_block"
+       filtered_blocks = [(name, param) for name, param in mmdit.named_parameters() if block_name in name]
+       accelerator.print(f"filtered_blocks: {len(filtered_blocks)}")
+
+       num_blocks_to_freeze = min(len(filtered_blocks), args.num_last_block_to_freeze)
+
+       accelerator.print(f"freeze_blocks: {num_blocks_to_freeze}")
+
+       start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze)
+
+       for i in range(start_freezing_from, len(filtered_blocks)):
+           _, param = filtered_blocks[i]
+           param.requires_grad = False
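The new `--num_last_block_to_freeze` option freezes the trailing `x_block` parameters of the MM-DiT. The selection logic can be sketched with plain objects standing in for named parameters (the `Param` class is a toy stand-in, not part of the diff):

```python
class Param:
    # Minimal stand-in for a torch parameter's requires_grad flag.
    def __init__(self):
        self.requires_grad = True


def freeze_last_blocks(named_params, num_last_block_to_freeze, block_name="x_block"):
    # Filter to block parameters, then flip requires_grad on the last n of them.
    filtered = [(name, p) for name, p in named_params if block_name in name]
    n_freeze = min(len(filtered), num_last_block_to_freeze)
    for _, p in filtered[len(filtered) - n_freeze:]:
        p.requires_grad = False


named_params = [(f"x_block.{i}.weight", Param()) for i in range(5)]
named_params.append(("final_layer.weight", Param()))

freeze_last_blocks(named_params, 2)
frozen = [name for name, p in named_params if not p.requires_grad]
assert frozen == ["x_block.3.weight", "x_block.4.weight"]
```

Filtering happens before slicing, so non-matching parameters (here `final_layer.weight`) are never frozen, and `min(...)` guards against asking for more blocks than exist.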

    training_models = []
    params_to_optimize = []
    # if train_unet:
    training_models.append(mmdit)
    # if block_lrs is None:
-   params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate})
+   params_to_optimize.append({"params": list(filter(lambda p: p.requires_grad, mmdit.parameters())), "lr": args.learning_rate})
    # else:
    #     params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs))

@@ -1026,12 +1046,17 @@ def setup_parser() -> argparse.ArgumentParser:
        default=None,
        help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
    )

    parser.add_argument(
        "--skip_latents_validity_check",
        action="store_true",
        help="skip latents validity check / latentsの正当性チェックをスキップする",
    )
+   parser.add_argument(
+       "--num_last_block_to_freeze",
+       type=int,
+       default=None,
+       help="freeze last n blocks of MM-DIT / MM-DITの最後のnブロックを凍結する",
+   )
    return parser
@@ -157,6 +157,9 @@ class NetworkTrainer:

    # region SD/SDXL

+   def post_process_network(self, args, accelerator, network, text_encoders, unet):
+       pass
+
    def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
        noise_scheduler = DDPMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False

@@ -237,6 +240,13 @@ class NetworkTrainer:
    def is_text_encoder_not_needed_for_training(self, args):
        return False  # use for sample images

+   def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
+       # set top parameter requires_grad = True for gradient checkpointing works
+       text_encoder.text_model.embeddings.requires_grad_(True)
+
+   def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
+       text_encoder.text_model.embeddings.to(dtype=weight_dtype)
+
    # endregion

    def train(self, args):

@@ -329,7 +339,7 @@ class NetworkTrainer:
            train_dataset_group.is_latent_cacheable()
        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

-       self.assert_extra_args(args, train_dataset_group)
+       self.assert_extra_args(args, train_dataset_group)  # may change some args

        # acceleratorを準備する / prepare accelerator
        logger.info("preparing accelerator")

@@ -428,12 +438,15 @@ class NetworkTrainer:
            )
            args.scale_weight_norms = False

+       self.post_process_network(args, accelerator, network, text_encoders, unet)
+
        # apply network to unet and text_encoder
        train_unet = not args.network_train_text_encoder_only
        train_text_encoder = self.is_train_text_encoder(args)
        network.apply_to(text_encoder, unet, train_text_encoder, train_unet)

        if args.network_weights is not None:
-           # FIXME consider alpha of weights
+           # FIXME consider alpha of weights: this assumes that the alpha is not changed
            info = network.load_weights(args.network_weights)
            accelerator.print(f"load network weights from {args.network_weights}: {info}")

@@ -533,7 +546,7 @@ class NetworkTrainer:
            ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
            accelerator.print("enable fp8 training for U-Net.")
            unet_weight_dtype = torch.float8_e4m3fn

            if not args.fp8_base_unet:
                accelerator.print("enable fp8 training for Text Encoder.")
            te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn

@@ -545,17 +558,16 @@ class NetworkTrainer:
        unet.requires_grad_(False)
        unet.to(dtype=unet_weight_dtype)
-       for t_enc in text_encoders:
+       for i, t_enc in enumerate(text_encoders):
            t_enc.requires_grad_(False)

            # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
            if t_enc.device.type != "cpu":
                t_enc.to(dtype=te_weight_dtype)
-               if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"):
-                   # nn.Embedding not support FP8
-                   t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
-               elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"):
-                   t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))

+               # nn.Embedding not support FP8
+               if te_weight_dtype != weight_dtype:
+                   self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype)

        # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
        if args.deepspeed:

@@ -596,12 +608,12 @@ class NetworkTrainer:
        if args.gradient_checkpointing:
            # according to TI example in Diffusers, train is required
            unet.train()
-           for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
+           for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))):
                t_enc.train()

                # set top parameter requires_grad = True for gradient checkpointing works
                if frag:
-                   t_enc.text_model.embeddings.requires_grad_(True)
+                   self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc)

        else:
            unet.eval()

@@ -1028,8 +1040,12 @@ class NetworkTrainer:

        # log device and dtype for each model
        logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
-       for t_enc in text_encoders:
-           logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}")
+       for i, t_enc in enumerate(text_encoders):
+           params_itr = t_enc.parameters()
+           params_itr.__next__()  # skip the first parameter
+           params_itr.__next__()  # skip the second parameter, because CLIP's first two parameters are embeddings
+           param_3rd = params_itr.__next__()
+           logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}")

        clean_memory_on_device(accelerator.device)

@@ -1081,15 +1097,11 @@ class NetworkTrainer:
                # print(f"set multiplier: {multipliers}")
                accelerator.unwrap_model(network).set_multiplier(multipliers)

+               text_encoder_conds = []
                text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
                if text_encoder_outputs_list is not None:
                    text_encoder_conds = text_encoder_outputs_list  # List of text encoder outputs
-               if (
-                   text_encoder_conds is None
-                   or len(text_encoder_conds) == 0
-                   or text_encoder_conds[0] is None
-                   or train_text_encoder
-               ):
+               if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
                    with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
                        # Get the text embedding for conditioning
                        if args.weighted_captions:

@@ -1112,10 +1124,14 @@ class NetworkTrainer:
                        if args.full_fp16:
                            encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]

-                   # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
-                   for i in range(len(encoded_text_encoder_conds)):
-                       if encoded_text_encoder_conds[i] is not None:
-                           text_encoder_conds[i] = encoded_text_encoder_conds[i]
+                   # if text_encoder_conds is not cached, use encoded_text_encoder_conds
+                   if len(text_encoder_conds) == 0:
+                       text_encoder_conds = encoded_text_encoder_conds
+                   else:
+                       # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
+                       for i in range(len(encoded_text_encoder_conds)):
+                           if encoded_text_encoder_conds[i] is not None:
+                               text_encoder_conds[i] = encoded_text_encoder_conds[i]
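The caching logic above falls back to the freshly encoded outputs when nothing was cached, and otherwise overwrites cached entries element-wise wherever a new output is available (e.g. a freshly trained CLIP-L output replacing its cached copy while a cached T5XXL output is kept). A compact sketch of that merge rule:

```python
def merge_conds(cached, encoded):
    # No cache: use the freshly encoded outputs wholesale.
    if len(cached) == 0:
        return list(encoded)
    # Otherwise overwrite per element, keeping cached entries where the
    # encoder produced nothing (None).
    return [e if e is not None else c for c, e in zip(cached, encoded)]


assert merge_conds([], ["l_new", "t5_new"]) == ["l_new", "t5_new"]
assert merge_conds(["l_cached", "t5_cached"], ["l_new", None]) == ["l_new", "t5_cached"]
```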

            # sample noise, call unet, get target
            noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(