mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
T5XXL LoRA training, fp8 T5XXL support
README.md
@@ -11,6 +11,11 @@ The command to install PyTorch is as follows:
 
 ### Recent Updates
 
+Sep 4, 2024:
+- T5XXL LoRA is now supported in LoRA training. Remove `--network_train_unet_only` and add `train_t5xxl=True` to `--network_args`. CLIP-L is also trained at the same time (T5XXL cannot be trained alone). The trained model can be used with ComfyUI.
+- In LoRA training, when `--fp8_base` is specified, you can specify `t5xxl_fp8_e4m3fn.safetensors` as the T5XXL weights. However, fp16 weights are recommended for caching.
+- Fixed an issue where the trained CLIP-L LoRA was not used in sample image generation during LoRA training.
+
 Sep 1, 2024:
 
 - `--timestep_sampling` now has a `flux_shift` option. Thanks to sdbds!
   - This is the same shift as FLUX.1 dev inference, adjusting the timestep sampling depending on the resolution. `--discrete_flow_shift` is ignored when `flux_shift` is specified. It is not yet verified which is better, `shift` or `flux_shift`.
@@ -41,8 +46,8 @@ Sample command is below. It will work with 24GB VRAM GPUs.
 
 ```
 accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py
---pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors
+--pretrained_model_name_or_path flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors
---ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers
+--ae ae.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers
 --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4
 --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base
@@ -72,6 +77,11 @@ The trained LoRA model can be used with ComfyUI.
 
 There are many unknown points in FLUX.1 training, so some settings can be specified by arguments. Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome.
 
+- `--pretrained_model_name_or_path` is the path to the pretrained model (FLUX.1). bf16 (original BFL model) is recommended (`flux1-dev.safetensors` or `flux1-dev.sft`). If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with the `float8_e4m3fn` format.
+- `--clip_l` is the path to the CLIP-L model.
+- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (`float8_e4m3fn`) models for T5XXL. However, fp16 models are recommended for caching.
+- `--ae` is the path to the autoencoder model (`ae.safetensors` or `ae.sft`).
+
 - `--timestep_sampling` is the method to sample timesteps (0-1):
   - `sigma`: sigma-based, same as SD3
   - `uniform`: uniform random
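As a rough illustration of the sampling options above: `uniform` draws the timestep directly from [0, 1), while the `sigmoid` option (whose `--sigmoid_scale` is discussed below) squashes a normal sample through a sigmoid so timesteps concentrate around 0.5. This is a simplified sketch under those assumptions, not the repository's exact implementation:

```python
import math
import random

def sample_timestep(method: str, sigmoid_scale: float = 1.0) -> float:
    """Return a timestep in [0, 1). Simplified sketch of two sampling options."""
    if method == "uniform":
        return random.random()  # uniform random over [0, 1)
    elif method == "sigmoid":
        # squash a standard normal sample through a sigmoid; sigmoid_scale
        # widens or narrows the spread around 0.5 (assumed behavior)
        x = random.gauss(0.0, 1.0) * sigmoid_scale
        return 1.0 / (1.0 + math.exp(-x))
    raise ValueError(f"unknown method: {method}")

ts = [sample_timestep("sigmoid") for _ in range(1000)]
print(min(ts) > 0.0 and max(ts) < 1.0)  # True: sigmoid samples stay in (0, 1)
```

The scale parameter and formula here are illustrative; consult the script source for the exact distribution used.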
@@ -114,16 +124,29 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times
 
 #### Key Features for FLUX.1 LoRA training
 
-1. CLIP-L LoRA Support:
+1. CLIP-L and T5XXL LoRA Support:
-   - FLUX.1 LoRA training now supports CLIP-L LoRA.
+   - FLUX.1 LoRA training now supports CLIP-L and T5XXL LoRA training.
    - Remove `--network_train_unet_only` from your command.
-   - T5XXL is not trained. Its output is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required.
+   - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time.
+   - T5XXL output can be cached for CLIP-L LoRA training, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
    - The trained LoRA can be used with ComfyUI.
-   - Note: `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA.
+   - Note: `flux_extract_lora.py`, `convert_flux_lora.py` and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet.
+
+   | trained LoRA | option | network_args | cache_text_encoder_outputs (*1) |
+   |---|---|---|---|
+   | FLUX.1 | `--network_train_unet_only` | - | o |
+   | FLUX.1 + CLIP-L | - | - | o (*2) |
+   | FLUX.1 + CLIP-L + T5XXL | - | `train_t5xxl=True` | - |
+   | CLIP-L (*3) | `--network_train_text_encoder_only` | - | o (*2) |
+   | CLIP-L + T5XXL (*3) | `--network_train_text_encoder_only` | `train_t5xxl=True` | - |
+
+   - *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
+   - *2: T5XXL output can be cached for CLIP-L LoRA training.
+   - *3: Not tested yet.
 
 2. Experimental FP8/FP16 mixed training:
-   - `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L.
+   - `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L/T5XXL.
-   - FLUX can be trained with fp8, and CLIP-L can be trained with bf16/fp16.
+   - FLUX can be trained with fp8, and CLIP-L/T5XXL can be trained with bf16/fp16.
    - When specifying this option, the `--fp8_base` option is automatically enabled.
 
 3. Split Q/K/V Projection Layers (Experimental):
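The table above maps CLI options to which modules train and whether text encoder outputs can be cached. That mapping can be sketched as a small hypothetical helper (the function name and tuple shape are illustrative; the trainer derives `train_clip_l` from `--network_train_unet_only` and `train_t5xxl` from `--network_args`, as in this commit):

```python
def resolve_train_flags(network_train_unet_only: bool, train_t5xxl: bool):
    """Return (train_clip_l, train_t5xxl, can_cache_te_outputs) per the table above."""
    train_clip_l = not network_train_unet_only  # CLIP-L trains unless unet-only
    train_t5xxl = train_t5xxl and train_clip_l  # T5XXL only trains alongside CLIP-L
    # caching text encoder outputs is only possible while T5XXL itself is frozen
    can_cache = not train_t5xxl
    return train_clip_l, train_t5xxl, can_cache

# rows of the table: FLUX.1 only / + CLIP-L / + CLIP-L + T5XXL
print(resolve_train_flags(True, False))   # (False, False, True)
print(resolve_train_flags(False, False))  # (True, False, True)
print(resolve_train_flags(False, True))   # (True, True, False)
```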
@@ -153,7 +176,7 @@ The compatibility of the saved model (state dict) is ensured by concatenating th
 The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.
 
 ```
-python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
+python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.safetensors --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
 ```
 
 ### FLUX.1 fine-tuning
@@ -164,7 +187,7 @@ Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GP
 
 ```
 accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py
---pretrained_model_name_or_path flux1-dev.sft --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.sft
+--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.safetensors
 --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2
 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16
 --dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name
@@ -256,7 +279,7 @@ CLIP-L LoRA is not supported.
 `networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint. __The script is experimental.__
 
 ```
-python networks/flux_merge_lora.py --flux_model flux1-dev.sft --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu
+python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu
 ```
 
 You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`.
 
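Conceptually, merging a LoRA into a checkpoint adds `ratio * (alpha / dim) * (up @ down)` to each target weight. A minimal numerical sketch of that math (plain Python for illustration, not the actual `flux_merge_lora.py` code):

```python
def matmul(a, b):
    # naive matrix multiply for small illustration matrices
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))] for i in range(len(a))]

def merge_lora(weight, up, down, alpha, ratio):
    """Return weight + ratio * (alpha / dim) * (up @ down); dim is the LoRA rank."""
    dim = len(down)  # rank = number of rows of the down projection
    scale = ratio * alpha / dim
    delta = matmul(up, down)
    return [[w + scale * d for w, d in zip(wr, dr)] for wr, dr in zip(weight, delta)]

w = [[1.0, 0.0], [0.0, 1.0]]  # 2x2 base weight
up = [[1.0], [0.0]]           # 2x1 up projection (rank 1)
down = [[0.0, 2.0]]           # 1x2 down projection
print(merge_lora(w, up, down, alpha=1.0, ratio=1.0))  # [[1.0, 2.0], [0.0, 1.0]]
```

Setting `--ratios 2.0`, as in the sample command, simply doubles the added delta.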
@@ -43,13 +43,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
 
             train_dataset_group.is_text_encoder_output_cacheable()
         ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
 
-        # assert (
-        #     args.network_train_unet_only or not args.cache_text_encoder_outputs
-        # ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
-        if not args.network_train_unet_only:
-            logger.info(
-                "network for CLIP-L only will be trained. T5XXL will not be trained / CLIP-Lのネットワークのみが学習されます。T5XXLは学習されません"
-            )
+        # prepare CLIP-L/T5XXL training flags
+        self.train_clip_l = not args.network_train_unet_only
+        self.train_t5xxl = False  # default is False even if args.network_train_unet_only is False
 
         if args.max_token_length is not None:
             logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
@@ -63,12 +59,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
 
         # currently offload to cpu for some models
         name = self.get_flux_model_name(args)
 
-        # if we load to cpu, flux.to(fp8) takes a long time
-        if args.fp8_base:
-            loading_dtype = None  # as is
-        else:
-            loading_dtype = weight_dtype
+        # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
+        loading_dtype = None if args.fp8_base else weight_dtype
+
+        # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
         model = flux_utils.load_flow_model(
             name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
         )
@@ -85,9 +79,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
 
         clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
         clip_l.eval()
 
+        # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
+        if args.fp8_base and not args.fp8_base_unet:
+            loading_dtype = None  # as is
+        else:
+            loading_dtype = weight_dtype
+
         # loading t5xxl to cpu takes a long time, so we should load to gpu in future
-        t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
+        t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
         t5xxl.eval()
+        if args.fp8_base and not args.fp8_base_unet:
+            # check dtype of model
+            if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
+                raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
+            elif t5xxl.dtype == torch.float8_e4m3fn:
+                logger.info("Loaded fp8 T5XXL model")
 
         ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
 
@@ -154,25 +160,35 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
 
     def get_text_encoding_strategy(self, args):
         return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
 
+    def post_process_network(self, args, accelerator, network, text_encoders, unet):
+        # check whether t5xxl is trained or not
+        self.train_t5xxl = network.train_t5xxl
+
+        if self.train_t5xxl and args.cache_text_encoder_outputs:
+            raise ValueError(
+                "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
+            )
+
     def get_models_for_text_encoding(self, args, accelerator, text_encoders):
         if args.cache_text_encoder_outputs:
-            if self.is_train_text_encoder(args):
+            if self.train_clip_l and not self.train_t5xxl:
                 return text_encoders[0:1]  # only CLIP-L is needed for encoding because T5XXL is cached
             else:
-                return text_encoders  # ignored
+                return None  # no text encoders are needed for encoding because both are cached
         else:
             return text_encoders  # both CLIP-L and T5XXL are needed for encoding
 
     def get_text_encoders_train_flags(self, args, text_encoders):
-        return [True, False] if self.is_train_text_encoder(args) else [False, False]
+        return [self.train_clip_l, self.train_t5xxl]
 
     def get_text_encoder_outputs_caching_strategy(self, args):
         if args.cache_text_encoder_outputs:
+            # if the text encoders are trained, we need tokenization, so is_partial is True
             return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
                 args.cache_text_encoder_outputs_to_disk,
                 None,
                 False,
-                is_partial=self.is_train_text_encoder(args),
+                is_partial=self.train_clip_l or self.train_t5xxl,
                 apply_t5_attn_mask=args.apply_t5_attn_mask,
             )
         else:
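The `get_models_for_text_encoding` change above decides which encoders must actually run at encoding time. Its decision logic can be sketched in isolation (a simplified model with plain strings standing in for the encoder modules):

```python
def models_for_text_encoding(text_encoders, cache_outputs, train_clip_l, train_t5xxl):
    """Mirror of the selection logic: which encoders are needed at encoding time."""
    if cache_outputs:
        if train_clip_l and not train_t5xxl:
            return text_encoders[0:1]  # only CLIP-L; T5XXL outputs come from cache
        else:
            return None  # both outputs are cached, no encoder needed
    else:
        return text_encoders  # both CLIP-L and T5XXL run every step

encoders = ["clip_l", "t5xxl"]
print(models_for_text_encoding(encoders, True, True, False))   # ['clip_l']
print(models_for_text_encoding(encoders, True, False, False))  # None
print(models_for_text_encoding(encoders, False, True, True))   # ['clip_l', 't5xxl']
```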
@@ -193,8 +209,16 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
 
             # When TE is not trained, it will not be prepared, so we need to use explicit autocast
             logger.info("move text encoders to gpu")
-            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
-            text_encoders[1].to(accelerator.device, dtype=weight_dtype)
+            text_encoders[0].to(accelerator.device, dtype=weight_dtype)  # always not fp8
+            text_encoders[1].to(accelerator.device)
+
+            if text_encoders[1].dtype == torch.float8_e4m3fn:
+                # if we load fp8 weights, the model is already fp8, so we use it as is
+                self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
+            else:
+                # otherwise, we need to convert it to target dtype
+                text_encoders[1].to(weight_dtype)
 
             with accelerator.autocast():
                 dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process)
@@ -235,7 +259,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
 
         else:
             # outputs are fetched from the Text Encoder every step, so keep it on the GPU / Text Encoderから毎回出力を取得するので、GPUに乗せておく
             text_encoders[0].to(accelerator.device, dtype=weight_dtype)
-            text_encoders[1].to(accelerator.device, dtype=weight_dtype)
+            text_encoders[1].to(accelerator.device)
 
         # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
         #     noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
@@ -255,9 +279,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
 
         #     return noise_pred
 
     def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
+        text_encoders = text_encoder  # for compatibility
+        text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
+
         if not args.split_mode:
             flux_train_utils.sample_images(
-                accelerator, args, epoch, global_step, flux, ae, text_encoder, self.sample_prompts_te_outputs
+                accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
             )
             return
 
@@ -281,7 +308,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
 
         wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
         clean_memory_on_device(accelerator.device)
         flux_train_utils.sample_images(
-            accelerator, args, epoch, global_step, wrapper, ae, text_encoder, self.sample_prompts_te_outputs
+            accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
         )
         clean_memory_on_device(accelerator.device)
 
@@ -421,6 +448,47 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
 
     def is_text_encoder_not_needed_for_training(self, args):
         return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
 
+    def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
+        if index == 0:  # CLIP-L
+            return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
+        else:  # T5XXL
+            text_encoder.encoder.embed_tokens.requires_grad_(True)
+
+    def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
+        if index == 0:  # CLIP-L
+            logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
+            text_encoder.to(te_weight_dtype)  # fp8
+            text_encoder.text_model.embeddings.to(dtype=weight_dtype)
+        else:  # T5XXL
+
+            def prepare_fp8(text_encoder, target_dtype):
+                def forward_hook(module):
+                    def forward(hidden_states):
+                        hidden_gelu = module.act(module.wi_0(hidden_states))
+                        hidden_linear = module.wi_1(hidden_states)
+                        hidden_states = hidden_gelu * hidden_linear
+                        hidden_states = module.dropout(hidden_states)
+
+                        hidden_states = module.wo(hidden_states)
+                        return hidden_states
+
+                    return forward
+
+                for module in text_encoder.modules():
+                    if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
+                        # print("set", module.__class__.__name__, "to", target_dtype)
+                        module.to(target_dtype)
+                    if module.__class__.__name__ in ["T5DenseGatedActDense"]:
+                        # print("set", module.__class__.__name__, "hooks")
+                        module.forward = forward_hook(module)
+
+            if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
+                logger.info("T5XXL already prepared for fp8")
+            else:
+                logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
+                text_encoder.to(te_weight_dtype)  # fp8
+                prepare_fp8(text_encoder, weight_dtype)
 
 def setup_parser() -> argparse.ArgumentParser:
     parser = train_network.setup_parser()
 
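The fp8 preparation above keeps norm and embedding layers in higher precision and monkey-patches `T5DenseGatedActDense.forward` with a closure so the gated feed-forward runs through the fp8 linears in a controlled way. The patching pattern itself, reduced to plain Python (a dummy module with hypothetical stand-in callables, no torch):

```python
class DummyDense:
    """Stand-in for T5DenseGatedActDense: a gated feed-forward over two projections."""
    def __init__(self):
        self.wi_0 = lambda x: x + 1      # hypothetical stand-ins for the linear layers
        self.wi_1 = lambda x: x * 2
        self.act = lambda x: max(x, 0)   # ReLU-like gate activation for illustration
        self.dropout = lambda x: x
        self.wo = lambda x: x - 1

    def forward(self, x):
        raise NotImplementedError("original forward, to be replaced")

def forward_hook(module):
    # closure captures `module`, same shape as the commit's prepare_fp8 hook
    def forward(hidden_states):
        hidden_gelu = module.act(module.wi_0(hidden_states))
        hidden_linear = module.wi_1(hidden_states)
        hidden_states = module.dropout(hidden_gelu * hidden_linear)
        return module.wo(hidden_states)
    return forward

m = DummyDense()
m.forward = forward_hook(m)  # replace the bound forward with the patched closure
print(m.forward(3))  # act(3+1)=4, 3*2=6 -> 4*6=24 -> wo: 23
```

The real hook exists so that the multiply and dropout happen outside the fp8 weights' own dtype; this sketch only demonstrates the closure-based patching mechanism.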
@@ -85,7 +85,7 @@ def sample_images(
 
     if distributed_state.num_processes <= 1:
         # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
-        with torch.no_grad():
+        with torch.no_grad(), accelerator.autocast():
             for prompt_dict in prompts:
                 sample_image_inference(
                     accelerator,
@@ -187,14 +187,27 @@ def sample_image_inference(
 
     tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
     encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
 
+    text_encoder_conds = []
     if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
-        te_outputs = sample_prompts_te_outputs[prompt]
-    else:
+        text_encoder_conds = sample_prompts_te_outputs[prompt]
+        print(f"Using cached text encoder outputs for prompt: {prompt}")
+    if text_encoders is not None:
+        print(f"Encoding prompt: {prompt}")
         tokens_and_masks = tokenize_strategy.tokenize(prompt)
         # strategy has apply_t5_attn_mask option
-        te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
+        encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
+        print([x.shape if x is not None else None for x in encoded_text_encoder_conds])
 
-    l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs
+        # if text_encoder_conds is not cached, use encoded_text_encoder_conds
+        if len(text_encoder_conds) == 0:
+            text_encoder_conds = encoded_text_encoder_conds
+        else:
+            # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
+            for i in range(len(encoded_text_encoder_conds)):
+                if encoded_text_encoder_conds[i] is not None:
+                    text_encoder_conds[i] = encoded_text_encoder_conds[i]
+
+    l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
 
     # sample image
     weight_dtype = ae.dtype  # TODO give dtype as argument
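The sample-generation change above merges cached conditioning with freshly encoded entries: when only CLIP-L is trained, its output is re-encoded while the cached T5XXL entries are reused. The merge step in isolation (plain lists as stand-ins for tensors, illustrative names):

```python
def merge_conds(cached, encoded):
    """Prefer freshly encoded entries; fall back to the cache where encoding returned None."""
    if not cached:
        return list(encoded)  # nothing cached: use the encoded outputs wholesale
    merged = list(cached)
    for i, item in enumerate(encoded):
        if item is not None:  # partial encoding: only trained encoders produce output
            merged[i] = item
    return merged

cached = ["l_pooled_cached", "t5_out_cached", "txt_ids", "t5_mask"]
encoded = ["l_pooled_new", None, None, None]  # only CLIP-L was re-encoded
print(merge_conds(cached, encoded))  # ['l_pooled_new', 't5_out_cached', 'txt_ids', 't5_mask']
```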
@@ -171,7 +171,9 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev
 
     return clip
 
-def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> T5EncoderModel:
+def load_t5xxl(
+    ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
+) -> T5EncoderModel:
     T5_CONFIG_JSON = """
 {
     "architectures": [
@@ -217,6 +219,11 @@ def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.devi
 
     return t5xxl
 
+def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype:
+    # nn.Embedding is the first layer, but it could be cast to bfloat16 or float32
+    return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype
 
 def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
     img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
     img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]
 
@@ -5,8 +5,7 @@ import torch
 
 import numpy as np
 from transformers import CLIPTokenizer, T5TokenizerFast
 
-from library import sd3_utils, train_util
-from library import sd3_models
+from library import flux_utils, train_util
 from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
 
 from library.utils import setup_logging
@@ -100,6 +99,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
 
         super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
         self.apply_t5_attn_mask = apply_t5_attn_mask
+        self.warn_fp8_weights = False
 
     def get_outputs_npz_path(self, image_abs_path: str) -> str:
         return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
 
@@ -144,6 +145,14 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
 
     def cache_batch_outputs(
         self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
     ):
+        if not self.warn_fp8_weights:
+            if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
+                logger.warning(
+                    "T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs."
+                    " / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
+                )
+            self.warn_fp8_weights = True
 
         flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
         captions = [info.caption for info in infos]
 
@@ -330,6 +330,11 @@ def create_network(
 
     if split_qkv is not None:
         split_qkv = True if split_qkv == "True" else False
 
+    # train T5XXL
+    train_t5xxl = kwargs.get("train_t5xxl", False)
+    if train_t5xxl is not None:
+        train_t5xxl = True if train_t5xxl == "True" else False
 
     # so many arguments... / すごく引数が多いな ( ^ω^)・・・
     network = LoRANetwork(
         text_encoders,
@@ -344,6 +349,7 @@ def create_network(
         conv_alpha=conv_alpha,
         train_blocks=train_blocks,
         split_qkv=split_qkv,
+        train_t5xxl=train_t5xxl,
         varbose=True,
     )

@@ -370,9 +376,10 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
     else:
         weights_sd = torch.load(file, map_location="cpu")

-    # get dim/alpha mapping
+    # get dim/alpha mapping, and train t5xxl
     modules_dim = {}
     modules_alpha = {}
+    train_t5xxl = None
     for key, value in weights_sd.items():
         if "." not in key:
             continue
@@ -385,6 +392,12 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
             modules_dim[lora_name] = dim
             # logger.info(lora_name, value.size(), dim)

+        if train_t5xxl is None:
+            train_t5xxl = "lora_te3" in lora_name
+
+    if train_t5xxl is None:
+        train_t5xxl = False
+
     # # split qkv
     # double_qkv_rank = None
     # single_qkv_rank = None
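The detection above works because T5XXL LoRA modules are saved under the `"lora_te3"` prefix while CLIP-L uses `"lora_te1"`. A self-contained sketch of that key scan:

```python
def detect_train_t5xxl(state_dict_keys) -> bool:
    # a saved LoRA contains T5XXL modules iff some module name carries
    # the "lora_te3" prefix (ComfyUI-compatible naming)
    for key in state_dict_keys:
        if "." not in key:
            continue
        lora_name = key.split(".")[0]
        if "lora_te3" in lora_name:
            return True
    return False

keys = [
    "lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_down.weight",
    "lora_te3_encoder_block_0_layer_0_SelfAttention_q.lora_down.weight",
]
print(detect_train_t5xxl(keys))        # True
print(detect_train_t5xxl(keys[:1]))    # False
```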
@@ -413,6 +426,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
         modules_alpha=modules_alpha,
         module_class=module_class,
         split_qkv=split_qkv,
+        train_t5xxl=train_t5xxl,
     )
     return network, weights_sd

@@ -421,10 +435,10 @@ class LoRANetwork(torch.nn.Module):
     # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
     FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"]
     FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"]
-    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"]
     LORA_PREFIX_FLUX = "lora_unet"  # make ComfyUI compatible
     LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
-    LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2"
+    LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3"  # make ComfyUI compatible

     def __init__(
         self,
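Since the T5 prefix changes here from `"lora_te2"` to `"lora_te3"` for ComfyUI compatibility, a state dict saved under the old prefix could in principle be renamed. This is a hypothetical migration sketch for illustration only; the commit itself ships no such converter:

```python
def rename_t5_prefix(state_dict: dict) -> dict:
    # hypothetical helper: move old-style "lora_te2_" keys to "lora_te3_",
    # leaving CLIP-L ("lora_te1_") and unet keys untouched
    old, new = "lora_te2_", "lora_te3_"
    return {
        ((new + k[len(old):]) if k.startswith(old) else k): v
        for k, v in state_dict.items()
    }

sd = {
    "lora_te2_encoder_block_0.lora_down.weight": 0,
    "lora_te1_text_model.lora_down.weight": 1,
}
renamed = rename_t5_prefix(sd)
print(sorted(renamed))
# ['lora_te1_text_model.lora_down.weight', 'lora_te3_encoder_block_0.lora_down.weight']
```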
@@ -443,6 +457,7 @@ class LoRANetwork(torch.nn.Module):
         modules_alpha: Optional[Dict[str, int]] = None,
         train_blocks: Optional[str] = None,
         split_qkv: bool = False,
+        train_t5xxl: bool = False,
         varbose: Optional[bool] = False,
     ) -> None:
         super().__init__()
@@ -457,6 +472,7 @@ class LoRANetwork(torch.nn.Module):
         self.module_dropout = module_dropout
         self.train_blocks = train_blocks if train_blocks is not None else "all"
         self.split_qkv = split_qkv
+        self.train_t5xxl = train_t5xxl

         self.loraplus_lr_ratio = None
         self.loraplus_unet_lr_ratio = None
@@ -469,12 +485,16 @@ class LoRANetwork(torch.nn.Module):
         logger.info(
             f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
         )
-        if self.conv_lora_dim is not None:
-            logger.info(
-                f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
-            )
+        # if self.conv_lora_dim is not None:
+        #     logger.info(
+        #         f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
+        #     )
         if self.split_qkv:
             logger.info(f"split qkv for LoRA")
+        if self.train_blocks is not None:
+            logger.info(f"train {self.train_blocks} blocks only")
+        if train_t5xxl:
+            logger.info(f"train T5XXL as well")

         # create module instances
         def create_modules(
@@ -550,12 +570,15 @@ class LoRANetwork(torch.nn.Module):
         skipped_te = []
         for i, text_encoder in enumerate(text_encoders):
             index = i
+            if not train_t5xxl and index > 0:  # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
+                break
+
             logger.info(f"create LoRA for Text Encoder {index+1}:")

             text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+            logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
             self.text_encoder_loras.extend(text_encoder_loras)
             skipped_te += skipped
-        logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

         # create LoRA for U-Net
         if self.train_blocks == "all":
@@ -157,6 +157,9 @@ class NetworkTrainer:

     # region SD/SDXL

+    def post_process_network(self, args, accelerator, network, text_encoders, unet):
+        pass
+
     def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
         noise_scheduler = DDPMScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
@@ -237,6 +240,13 @@ class NetworkTrainer:
     def is_text_encoder_not_needed_for_training(self, args):
         return False  # use for sample images

+    def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
+        # set top parameter requires_grad = True for gradient checkpointing works
+        text_encoder.text_model.embeddings.requires_grad_(True)
+
+    def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
+        text_encoder.text_model.embeddings.to(dtype=weight_dtype)
+
     # endregion

     def train(self, args):
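The `prepare_text_encoder_fp8` hook above keeps the embedding table in the regular weight dtype while the rest of the encoder sits in fp8, because `nn.Embedding` does not support fp8. A runnable sketch of that flow, with dummy objects standing in for torch modules and string dtypes in place of `torch.float8_e4m3fn` / `torch.float16`:

```python
class DummyModule:
    def __init__(self):
        self.dtype = "float32"

    def to(self, dtype):
        self.dtype = dtype
        return self

class DummyTextModel:
    def __init__(self):
        self.embeddings = DummyModule()

class DummyCLIP:
    def __init__(self):
        self.text_model = DummyTextModel()
        self.dtype = "float32"

    def to(self, dtype):
        # casting the whole model also casts submodules, as torch's .to does
        self.dtype = dtype
        self.text_model.embeddings.to(dtype)
        return self

def prepare_text_encoder_fp8(text_encoder, te_weight_dtype, weight_dtype):
    text_encoder.to(te_weight_dtype)                      # whole encoder to fp8
    text_encoder.text_model.embeddings.to(weight_dtype)   # embeddings back to fp16/bf16

enc = DummyCLIP()
prepare_text_encoder_fp8(enc, "float8_e4m3fn", "float16")
print(enc.dtype, enc.text_model.embeddings.dtype)  # float8_e4m3fn float16
```

In the actual training loop the fp8 cast and the embedding fix-up are separate steps; the sketch combines them only to show the end state.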
@@ -329,7 +339,7 @@ class NetworkTrainer:
                 train_dataset_group.is_latent_cacheable()
             ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

-        self.assert_extra_args(args, train_dataset_group)
+        self.assert_extra_args(args, train_dataset_group)  # may change some args

         # acceleratorを準備する
         logger.info("preparing accelerator")
@@ -428,12 +438,15 @@ class NetworkTrainer:
             )
             args.scale_weight_norms = False

+        self.post_process_network(args, accelerator, network, text_encoders, unet)
+
+        # apply network to unet and text_encoder
         train_unet = not args.network_train_text_encoder_only
         train_text_encoder = self.is_train_text_encoder(args)
         network.apply_to(text_encoder, unet, train_text_encoder, train_unet)

         if args.network_weights is not None:
-            # FIXME consider alpha of weights
+            # FIXME consider alpha of weights: this assumes that the alpha is not changed
             info = network.load_weights(args.network_weights)
             accelerator.print(f"load network weights from {args.network_weights}: {info}")

@@ -545,17 +558,16 @@ class NetworkTrainer:

         unet.requires_grad_(False)
         unet.to(dtype=unet_weight_dtype)
-        for t_enc in text_encoders:
+        for i, t_enc in enumerate(text_encoders):
             t_enc.requires_grad_(False)

             # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
             if t_enc.device.type != "cpu":
                 t_enc.to(dtype=te_weight_dtype)
-                if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"):
-                    # nn.Embedding not support FP8
-                    t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
-                elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"):
-                    t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
+                # nn.Embedding not support FP8
+                if te_weight_dtype != weight_dtype:
+                    self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype)

         # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
         if args.deepspeed:
@@ -596,12 +608,12 @@ class NetworkTrainer:
         if args.gradient_checkpointing:
             # according to TI example in Diffusers, train is required
             unet.train()
-            for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
+            for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))):
                 t_enc.train()

                 # set top parameter requires_grad = True for gradient checkpointing works
                 if frag:
-                    t_enc.text_model.embeddings.requires_grad_(True)
+                    self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc)

         else:
             unet.eval()
@@ -1028,8 +1040,12 @@ class NetworkTrainer:

         # log device and dtype for each model
         logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
-        for t_enc in text_encoders:
-            logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}")
+        for i, t_enc in enumerate(text_encoders):
+            params_itr = t_enc.parameters()
+            params_itr.__next__()  # skip the first parameter
+            params_itr.__next__()  # skip the second parameter. because CLIP first two parameters are embeddings
+            param_3rd = params_itr.__next__()
+            logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}")

         clean_memory_on_device(accelerator.device)

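The logging change above peeks at the third parameter because the module-level dtype is misleading when the embedding tables (CLIP's first two parameters) were kept in fp16 while the rest of the encoder is fp8. A stdlib-only sketch of that inspection, using `itertools.islice` in place of repeated `__next__` calls:

```python
from itertools import islice

class P:
    # stand-in for a torch parameter
    def __init__(self, dtype):
        self.dtype = dtype

def representative_dtype(parameters, skip=2):
    # `parameters` is any iterator of objects with a .dtype attribute;
    # skip the embedding parameters and report the dtype of the next one
    return next(islice(parameters, skip, None)).dtype

# embeddings kept in fp16, remaining weights in fp8
params = [P("float16"), P("float16"), P("float8_e4m3fn")]
print(representative_dtype(iter(params)))  # float8_e4m3fn
```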
@@ -1085,11 +1101,7 @@ class NetworkTrainer:
                 text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
                 if text_encoder_outputs_list is not None:
                     text_encoder_conds = text_encoder_outputs_list  # List of text encoder outputs
-                if (
-                    len(text_encoder_conds) == 0
-                    or text_encoder_conds[0] is None
-                    or train_text_encoder
-                ):
+                if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
                     with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
                         # Get the text embedding for conditioning
                         if args.weighted_captions: