mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
T5XXL LoRA training, fp8 T5XXL support
This commit is contained in:
45
README.md
45
README.md
@@ -11,6 +11,11 @@ The command to install PyTorch is as follows:
|
||||
|
||||
### Recent Updates
|
||||
|
||||
Sep 4, 2024:
|
||||
- T5XXL LoRA is supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL only cannot be trained). The trained model can be used with ComfyUI.
|
||||
- In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, it is recommended to use fp16 weights for caching.
|
||||
- Fixed an issue where the training CLIP-L LoRA was not used in sample image generation during LoRA training.
|
||||
|
||||
Sep 1, 2024:
|
||||
- `--timestamp_sampling` has `flux_shift` option. Thanks to sdbds!
|
||||
- This is the same shift as FLUX.1 dev inference, adjusting the timestep sampling depending on the resolution. `--discrete_flow_shift` is ignored when `flux_shift` is specified. It is not verified which is better, `shift` or `flux_shift`.
|
||||
@@ -41,8 +46,8 @@ Sample command is below. It will work with 24GB VRAM GPUs.
|
||||
|
||||
```
|
||||
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py
|
||||
--pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors
|
||||
--ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers
|
||||
--pretrained_model_name_or_path flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors
|
||||
--ae ae.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers
|
||||
--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
|
||||
--network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4
|
||||
--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base
|
||||
@@ -72,6 +77,11 @@ The trained LoRA model can be used with ComfyUI.
|
||||
|
||||
There are many unknown points in FLUX.1 training, so some settings can be specified by arguments. Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome.
|
||||
|
||||
- `--pretrained_model_name_or_path` is the path to the pretrained model (FLUX.1). bf16 (original BFL model) is recommended (`flux1-dev.safetensors` or `flux1-dev.sft`). If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with `float8_e4m3fn` format.
|
||||
- `--clip_l` is the path to the CLIP-L model.
|
||||
- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (float8_e4m3fn) models for T5XXL. However, it is recommended to use fp16 models for caching.
|
||||
- `--ae` is the path to the autoencoder model (`ae.safetensors` or `ae.sft`).
|
||||
|
||||
- `--timestep_sampling` is the method to sample timesteps (0-1):
|
||||
- `sigma`: sigma-based, same as SD3
|
||||
- `uniform`: uniform random
|
||||
@@ -114,16 +124,29 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times
|
||||
|
||||
#### Key Features for FLUX.1 LoRA training
|
||||
|
||||
1. CLIP-L LoRA Support:
|
||||
- FLUX.1 LoRA training now supports CLIP-L LoRA.
|
||||
1. CLIP-L and T5XXL LoRA Support:
|
||||
- FLUX.1 LoRA training now supports CLIP-L and T5XXL LoRA training.
|
||||
- Remove `--network_train_unet_only` from your command.
|
||||
- T5XXL is not trained. Its output is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required.
|
||||
- Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time.
|
||||
- T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
|
||||
- The trained LoRA can be used with ComfyUI.
|
||||
- Note: `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA.
|
||||
- Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet.
|
||||
|
||||
| trained LoRA|option|network_args|cache_text_encoder_outputs (*1)|
|
||||
|---|---|---|---|
|
||||
|FLUX.1|`--network_train_unet_only`|-|o|
|
||||
|FLUX.1 + CLIP-L|-|-|o (*2)|
|
||||
|FLUX.1 + CLIP-L + T5XXL|-|`train_t5xxl=True`|-|
|
||||
|CLIP-L (*3)|`--network_train_text_encoder_only`|-|o (*2)|
|
||||
|CLIP-L + T5XXL (*3)|`--network_train_text_encoder_only`|`train_t5xxl=True`|-|
|
||||
|
||||
- *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
|
||||
- *2: T5XXL output can be cached for CLIP-L LoRA training.
|
||||
- *3: Not tested yet.
|
||||
|
||||
2. Experimental FP8/FP16 mixed training:
|
||||
- `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L.
|
||||
- FLUX can be trained with fp8, and CLIP-L can be trained with bf16/fp16.
|
||||
- `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L/T5XXL.
|
||||
- FLUX can be trained with fp8, and CLIP-L/T5XXL can be trained with bf16/fp16.
|
||||
- When specifying this option, the `--fp8_base` option is automatically enabled.
|
||||
|
||||
3. Split Q/K/V Projection Layers (Experimental):
|
||||
@@ -153,7 +176,7 @@ The compatibility of the saved model (state dict) is ensured by concatenating th
|
||||
The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.
|
||||
|
||||
```
|
||||
python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
|
||||
python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.safetensors --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
|
||||
```
|
||||
|
||||
### FLUX.1 fine-tuning
|
||||
@@ -164,7 +187,7 @@ Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GP
|
||||
|
||||
```
|
||||
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py
|
||||
--pretrained_model_name_or_path flux1-dev.sft --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.sft
|
||||
--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.safetensors
|
||||
--save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2
|
||||
--seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
|
||||
--dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name
|
||||
@@ -256,7 +279,7 @@ CLIP-L LoRA is not supported.
|
||||
`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__
|
||||
|
||||
```
|
||||
python networks/flux_merge_lora.py --flux_model flux1-dev.sft --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu
|
||||
python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu
|
||||
```
|
||||
|
||||
You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`.
|
||||
|
||||
@@ -43,13 +43,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
train_dataset_group.is_text_encoder_output_cacheable()
|
||||
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
|
||||
# assert (
|
||||
# args.network_train_unet_only or not args.cache_text_encoder_outputs
|
||||
# ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
|
||||
if not args.network_train_unet_only:
|
||||
logger.info(
|
||||
"network for CLIP-L only will be trained. T5XXL will not be trained / CLIP-Lのネットワークのみが学習されます。T5XXLは学習されません"
|
||||
)
|
||||
# prepare CLIP-L/T5XXL training flags
|
||||
self.train_clip_l = not args.network_train_unet_only
|
||||
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
|
||||
|
||||
if args.max_token_length is not None:
|
||||
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
|
||||
@@ -63,12 +59,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
# currently offload to cpu for some models
|
||||
name = self.get_flux_model_name(args)
|
||||
|
||||
# if we load to cpu, flux.to(fp8) takes a long time
|
||||
if args.fp8_base:
|
||||
loading_dtype = None # as is
|
||||
else:
|
||||
loading_dtype = weight_dtype
|
||||
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
|
||||
loading_dtype = None if args.fp8_base else weight_dtype
|
||||
|
||||
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
|
||||
model = flux_utils.load_flow_model(
|
||||
name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
|
||||
)
|
||||
@@ -85,9 +79,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
||||
clip_l.eval()
|
||||
|
||||
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
|
||||
if args.fp8_base and not args.fp8_base_unet:
|
||||
loading_dtype = None # as is
|
||||
else:
|
||||
loading_dtype = weight_dtype
|
||||
|
||||
# loading t5xxl to cpu takes a long time, so we should load to gpu in future
|
||||
t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
||||
t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
||||
t5xxl.eval()
|
||||
if args.fp8_base and not args.fp8_base_unet:
|
||||
# check dtype of model
|
||||
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
|
||||
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
|
||||
elif t5xxl.dtype == torch.float8_e4m3fn:
|
||||
logger.info("Loaded fp8 T5XXL model")
|
||||
|
||||
ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
||||
|
||||
@@ -154,25 +160,35 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
def get_text_encoding_strategy(self, args):
|
||||
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
|
||||
|
||||
def post_process_network(self, args, accelerator, network, text_encoders, unet):
|
||||
# check t5xxl is trained or not
|
||||
self.train_t5xxl = network.train_t5xxl
|
||||
|
||||
if self.train_t5xxl and args.cache_text_encoder_outputs:
|
||||
raise ValueError(
|
||||
"T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
|
||||
)
|
||||
|
||||
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
||||
if args.cache_text_encoder_outputs:
|
||||
if self.is_train_text_encoder(args):
|
||||
if self.train_clip_l and not self.train_t5xxl:
|
||||
return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
|
||||
else:
|
||||
return text_encoders # ignored
|
||||
return None # no text encoders are needed for encoding because both are cached
|
||||
else:
|
||||
return text_encoders # both CLIP-L and T5XXL are needed for encoding
|
||||
|
||||
def get_text_encoders_train_flags(self, args, text_encoders):
|
||||
return [True, False] if self.is_train_text_encoder(args) else [False, False]
|
||||
return [self.train_clip_l, self.train_t5xxl]
|
||||
|
||||
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||
if args.cache_text_encoder_outputs:
|
||||
# if the text encoders is trained, we need tokenization, so is_partial is True
|
||||
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
None,
|
||||
False,
|
||||
is_partial=self.is_train_text_encoder(args),
|
||||
is_partial=self.train_clip_l or self.train_t5xxl,
|
||||
apply_t5_attn_mask=args.apply_t5_attn_mask,
|
||||
)
|
||||
else:
|
||||
@@ -193,8 +209,16 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
|
||||
logger.info("move text encoders to gpu")
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
|
||||
text_encoders[1].to(accelerator.device)
|
||||
|
||||
if text_encoders[1].dtype == torch.float8_e4m3fn:
|
||||
# if we load fp8 weights, the model is already fp8, so we use it as is
|
||||
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
|
||||
else:
|
||||
# otherwise, we need to convert it to target dtype
|
||||
text_encoders[1].to(weight_dtype)
|
||||
|
||||
with accelerator.autocast():
|
||||
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process)
|
||||
|
||||
@@ -235,7 +259,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
else:
|
||||
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoders[1].to(accelerator.device)
|
||||
|
||||
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||
@@ -255,9 +279,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
# return noise_pred
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
|
||||
text_encoders = text_encoder # for compatibility
|
||||
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
|
||||
|
||||
if not args.split_mode:
|
||||
flux_train_utils.sample_images(
|
||||
accelerator, args, epoch, global_step, flux, ae, text_encoder, self.sample_prompts_te_outputs
|
||||
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
|
||||
)
|
||||
return
|
||||
|
||||
@@ -281,7 +308,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
flux_train_utils.sample_images(
|
||||
accelerator, args, epoch, global_step, wrapper, ae, text_encoder, self.sample_prompts_te_outputs
|
||||
accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
|
||||
)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -421,6 +448,47 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
def is_text_encoder_not_needed_for_training(self, args):
|
||||
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
|
||||
|
||||
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
|
||||
if index == 0: # CLIP-L
|
||||
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
|
||||
else: # T5XXL
|
||||
text_encoder.encoder.embed_tokens.requires_grad_(True)
|
||||
|
||||
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
|
||||
if index == 0: # CLIP-L
|
||||
logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
|
||||
text_encoder.to(te_weight_dtype) # fp8
|
||||
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
|
||||
else: # T5XXL
|
||||
|
||||
def prepare_fp8(text_encoder, target_dtype):
|
||||
def forward_hook(module):
|
||||
def forward(hidden_states):
|
||||
hidden_gelu = module.act(module.wi_0(hidden_states))
|
||||
hidden_linear = module.wi_1(hidden_states)
|
||||
hidden_states = hidden_gelu * hidden_linear
|
||||
hidden_states = module.dropout(hidden_states)
|
||||
|
||||
hidden_states = module.wo(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
||||
for module in text_encoder.modules():
|
||||
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
|
||||
# print("set", module.__class__.__name__, "to", target_dtype)
|
||||
module.to(target_dtype)
|
||||
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
|
||||
# print("set", module.__class__.__name__, "hooks")
|
||||
module.forward = forward_hook(module)
|
||||
|
||||
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
|
||||
logger.info(f"T5XXL already prepared for fp8")
|
||||
else:
|
||||
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
|
||||
text_encoder.to(te_weight_dtype) # fp8
|
||||
prepare_fp8(text_encoder, weight_dtype)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = train_network.setup_parser()
|
||||
|
||||
@@ -85,7 +85,7 @@ def sample_images(
|
||||
|
||||
if distributed_state.num_processes <= 1:
|
||||
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
||||
with torch.no_grad():
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
for prompt_dict in prompts:
|
||||
sample_image_inference(
|
||||
accelerator,
|
||||
@@ -187,14 +187,27 @@ def sample_image_inference(
|
||||
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
|
||||
text_encoder_conds = []
|
||||
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
|
||||
te_outputs = sample_prompts_te_outputs[prompt]
|
||||
else:
|
||||
text_encoder_conds = sample_prompts_te_outputs[prompt]
|
||||
print(f"Using cached text encoder outputs for prompt: {prompt}")
|
||||
if text_encoders is not None:
|
||||
print(f"Encoding prompt: {prompt}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||
# strategy has apply_t5_attn_mask option
|
||||
te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
||||
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
||||
print([x.shape if x is not None else None for x in encoded_text_encoder_conds])
|
||||
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
text_encoder_conds = encoded_text_encoder_conds
|
||||
else:
|
||||
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
||||
for i in range(len(encoded_text_encoder_conds)):
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
|
||||
# sample image
|
||||
weight_dtype = ae.dtype # TOFO give dtype as argument
|
||||
|
||||
@@ -171,7 +171,9 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev
|
||||
return clip
|
||||
|
||||
|
||||
def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> T5EncoderModel:
|
||||
def load_t5xxl(
|
||||
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
|
||||
) -> T5EncoderModel:
|
||||
T5_CONFIG_JSON = """
|
||||
{
|
||||
"architectures": [
|
||||
@@ -217,6 +219,11 @@ def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.devi
|
||||
return t5xxl
|
||||
|
||||
|
||||
def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype:
|
||||
# nn.Embedding is the first layer, but it could be casted to bfloat16 or float32
|
||||
return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype
|
||||
|
||||
|
||||
def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
|
||||
img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]
|
||||
|
||||
@@ -5,8 +5,7 @@ import torch
|
||||
import numpy as np
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
|
||||
from library import sd3_utils, train_util
|
||||
from library import sd3_models
|
||||
from library import flux_utils, train_util
|
||||
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
||||
|
||||
from library.utils import setup_logging
|
||||
@@ -100,6 +99,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
||||
self.apply_t5_attn_mask = apply_t5_attn_mask
|
||||
|
||||
self.warn_fp8_weights = False
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
|
||||
@@ -144,6 +145,14 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
def cache_batch_outputs(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
||||
):
|
||||
if not self.warn_fp8_weights:
|
||||
if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
|
||||
logger.warning(
|
||||
"T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs."
|
||||
" / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
|
||||
)
|
||||
self.warn_fp8_weights = True
|
||||
|
||||
flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
|
||||
captions = [info.caption for info in infos]
|
||||
|
||||
|
||||
@@ -330,6 +330,11 @@ def create_network(
|
||||
if split_qkv is not None:
|
||||
split_qkv = True if split_qkv == "True" else False
|
||||
|
||||
# train T5XXL
|
||||
train_t5xxl = kwargs.get("train_t5xxl", False)
|
||||
if train_t5xxl is not None:
|
||||
train_t5xxl = True if train_t5xxl == "True" else False
|
||||
|
||||
# すごく引数が多いな ( ^ω^)・・・
|
||||
network = LoRANetwork(
|
||||
text_encoders,
|
||||
@@ -344,6 +349,7 @@ def create_network(
|
||||
conv_alpha=conv_alpha,
|
||||
train_blocks=train_blocks,
|
||||
split_qkv=split_qkv,
|
||||
train_t5xxl=train_t5xxl,
|
||||
varbose=True,
|
||||
)
|
||||
|
||||
@@ -370,9 +376,10 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
# get dim/alpha mapping
|
||||
# get dim/alpha mapping, and train t5xxl
|
||||
modules_dim = {}
|
||||
modules_alpha = {}
|
||||
train_t5xxl = None
|
||||
for key, value in weights_sd.items():
|
||||
if "." not in key:
|
||||
continue
|
||||
@@ -385,6 +392,12 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
|
||||
modules_dim[lora_name] = dim
|
||||
# logger.info(lora_name, value.size(), dim)
|
||||
|
||||
if train_t5xxl is None:
|
||||
train_t5xxl = "lora_te3" in lora_name
|
||||
|
||||
if train_t5xxl is None:
|
||||
train_t5xxl = False
|
||||
|
||||
# # split qkv
|
||||
# double_qkv_rank = None
|
||||
# single_qkv_rank = None
|
||||
@@ -413,6 +426,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
|
||||
modules_alpha=modules_alpha,
|
||||
module_class=module_class,
|
||||
split_qkv=split_qkv,
|
||||
train_t5xxl=train_t5xxl,
|
||||
)
|
||||
return network, weights_sd
|
||||
|
||||
@@ -421,10 +435,10 @@ class LoRANetwork(torch.nn.Module):
|
||||
# FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
|
||||
FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"]
|
||||
FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"]
|
||||
LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible
|
||||
LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
|
||||
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2"
|
||||
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -443,6 +457,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
modules_alpha: Optional[Dict[str, int]] = None,
|
||||
train_blocks: Optional[str] = None,
|
||||
split_qkv: bool = False,
|
||||
train_t5xxl: bool = False,
|
||||
varbose: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -457,6 +472,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
self.module_dropout = module_dropout
|
||||
self.train_blocks = train_blocks if train_blocks is not None else "all"
|
||||
self.split_qkv = split_qkv
|
||||
self.train_t5xxl = train_t5xxl
|
||||
|
||||
self.loraplus_lr_ratio = None
|
||||
self.loraplus_unet_lr_ratio = None
|
||||
@@ -469,12 +485,16 @@ class LoRANetwork(torch.nn.Module):
|
||||
logger.info(
|
||||
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
||||
)
|
||||
if self.conv_lora_dim is not None:
|
||||
logger.info(
|
||||
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
|
||||
)
|
||||
# if self.conv_lora_dim is not None:
|
||||
# logger.info(
|
||||
# f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
|
||||
# )
|
||||
if self.split_qkv:
|
||||
logger.info(f"split qkv for LoRA")
|
||||
if self.train_blocks is not None:
|
||||
logger.info(f"train {self.train_blocks} blocks only")
|
||||
if train_t5xxl:
|
||||
logger.info(f"train T5XXL as well")
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
@@ -550,12 +570,15 @@ class LoRANetwork(torch.nn.Module):
|
||||
skipped_te = []
|
||||
for i, text_encoder in enumerate(text_encoders):
|
||||
index = i
|
||||
if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
|
||||
break
|
||||
|
||||
logger.info(f"create LoRA for Text Encoder {index+1}:")
|
||||
|
||||
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
|
||||
self.text_encoder_loras.extend(text_encoder_loras)
|
||||
skipped_te += skipped
|
||||
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
|
||||
# create LoRA for U-Net
|
||||
if self.train_blocks == "all":
|
||||
|
||||
@@ -157,6 +157,9 @@ class NetworkTrainer:
|
||||
|
||||
# region SD/SDXL
|
||||
|
||||
def post_process_network(self, args, accelerator, network, text_encoders, unet):
|
||||
pass
|
||||
|
||||
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||
@@ -237,6 +240,13 @@ class NetworkTrainer:
|
||||
def is_text_encoder_not_needed_for_training(self, args):
|
||||
return False # use for sample images
|
||||
|
||||
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
text_encoder.text_model.embeddings.requires_grad_(True)
|
||||
|
||||
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
|
||||
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
|
||||
|
||||
# endregion
|
||||
|
||||
def train(self, args):
|
||||
@@ -329,7 +339,7 @@ class NetworkTrainer:
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
self.assert_extra_args(args, train_dataset_group)
|
||||
self.assert_extra_args(args, train_dataset_group) # may change some args
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("preparing accelerator")
|
||||
@@ -428,12 +438,15 @@ class NetworkTrainer:
|
||||
)
|
||||
args.scale_weight_norms = False
|
||||
|
||||
self.post_process_network(args, accelerator, network, text_encoders, unet)
|
||||
|
||||
# apply network to unet and text_encoder
|
||||
train_unet = not args.network_train_text_encoder_only
|
||||
train_text_encoder = self.is_train_text_encoder(args)
|
||||
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
|
||||
|
||||
if args.network_weights is not None:
|
||||
# FIXME consider alpha of weights
|
||||
# FIXME consider alpha of weights: this assumes that the alpha is not changed
|
||||
info = network.load_weights(args.network_weights)
|
||||
accelerator.print(f"load network weights from {args.network_weights}: {info}")
|
||||
|
||||
@@ -533,7 +546,7 @@ class NetworkTrainer:
|
||||
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
|
||||
accelerator.print("enable fp8 training for U-Net.")
|
||||
unet_weight_dtype = torch.float8_e4m3fn
|
||||
|
||||
|
||||
if not args.fp8_base_unet:
|
||||
accelerator.print("enable fp8 training for Text Encoder.")
|
||||
te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn
|
||||
@@ -545,17 +558,16 @@ class NetworkTrainer:
|
||||
|
||||
unet.requires_grad_(False)
|
||||
unet.to(dtype=unet_weight_dtype)
|
||||
for t_enc in text_encoders:
|
||||
for i, t_enc in enumerate(text_encoders):
|
||||
t_enc.requires_grad_(False)
|
||||
|
||||
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
|
||||
if t_enc.device.type != "cpu":
|
||||
t_enc.to(dtype=te_weight_dtype)
|
||||
if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"):
|
||||
# nn.Embedding not support FP8
|
||||
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||
elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"):
|
||||
t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||
|
||||
# nn.Embedding not support FP8
|
||||
if te_weight_dtype != weight_dtype:
|
||||
self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
||||
if args.deepspeed:
|
||||
@@ -596,12 +608,12 @@ class NetworkTrainer:
|
||||
if args.gradient_checkpointing:
|
||||
# according to TI example in Diffusers, train is required
|
||||
unet.train()
|
||||
for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
|
||||
for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))):
|
||||
t_enc.train()
|
||||
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
if frag:
|
||||
t_enc.text_model.embeddings.requires_grad_(True)
|
||||
self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc)
|
||||
|
||||
else:
|
||||
unet.eval()
|
||||
@@ -1028,8 +1040,12 @@ class NetworkTrainer:
|
||||
|
||||
# log device and dtype for each model
|
||||
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
|
||||
for t_enc in text_encoders:
|
||||
logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}")
|
||||
for i, t_enc in enumerate(text_encoders):
|
||||
params_itr = t_enc.parameters()
|
||||
params_itr.__next__() # skip the first parameter
|
||||
params_itr.__next__() # skip the second parameter. because CLIP first two parameters are embeddings
|
||||
param_3rd = params_itr.__next__()
|
||||
logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}")
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -1085,11 +1101,7 @@ class NetworkTrainer:
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
||||
if (
|
||||
len(text_encoder_conds) == 0
|
||||
or text_encoder_conds[0] is None
|
||||
or train_text_encoder
|
||||
):
|
||||
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
|
||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
|
||||
Reference in New Issue
Block a user