diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..2eddedc7 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,48 @@ +name: Test with pytest + +on: + push: + branches: + - main + - dev + - sd3 + pull_request: + branches: + - main + - dev + - sd3 + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: ["3.10"] # Python versions to test + pytorch-version: ["2.4.0"] # PyTorch versions to test + + steps: + - uses: actions/checkout@v4 + with: + # https://woodruffw.github.io/zizmor/audits/#artipacked + persist-credentials: false + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + + - name: Install and update pip, setuptools, wheel + run: | + # Setuptools, wheel for compiling some packages + python -m pip install --upgrade pip setuptools wheel + + - name: Install dependencies + run: | + # Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch) + pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision==0.19.0 pytest==8.3.4 + pip install -r requirements.txt + + - name: Test with pytest + run: pytest # See pytest.ini for configuration + diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index e8b06483..f53cda21 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -1,9 +1,11 @@ --- -# yamllint disable rule:line-length name: Typos -on: # yamllint disable-line rule:truthy +on: push: + branches: + - main + - dev pull_request: types: - opened @@ -16,6 +18,9 @@ jobs: steps: - uses: actions/checkout@v4 + with: + # https://woodruffw.github.io/zizmor/audits/#artipacked + persist-credentials: false - name: typos-action - uses: crate-ci/typos@v1.19.0 + uses: crate-ci/typos@v1.28.1 diff --git a/README-ja.md b/README-ja.md index 4ae6b233..27cc56c3 100644 --- a/README-ja.md +++ b/README-ja.md @@ -3,6 +3,10 @@ Stable Diffusionの学習、画像生成、その他のスクリプトを入れ [README in English](./README.md) ←更新情報はこちらにあります +開発中のバージョンはdevブランチにあります。最新の変更点はdevブランチをご確認ください。 + +FLUX.1およびSD3/SD3.5対応はsd3ブランチで行っています。それらの学習を行う場合はsd3ブランチをご利用ください。 + GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。 以下のスクリプトがあります。 diff --git a/README.md b/README.md index a7047a36..bfb22bcf 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,744 @@ This repository contains training, generation and utility scripts for Stable Diffusion. +## FLUX.1 and SD3 training (WIP) + +This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. + +__Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchvision==0.19.0` with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ + +The command to install PyTorch is as follows: +`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` + +- [FLUX.1 training](#flux1-training) +- [SD3 training](#sd3-training) + +### Recent Updates + +Dec 15, 2024: + +- RAdamScheduleFree optimizer is supported. PR [#1830](https://github.com/kohya-ss/sd-scripts/pull/1830) Thanks to nhamanasu! + - Update to `schedulefree==1.4` is required. Please update individually or with `pip install --use-pep517 --upgrade -r requirements.txt`. + - Available with `--optimizer_type=RAdamScheduleFree`. No need to specify warm up steps as well as learning rate scheduler. + +Dec 7, 2024: + +- The option to specify the model name during ControlNet training was different in each script. It has been unified. Please specify `--controlnet_model_name_or_path`. PR [#1821](https://github.com/kohya-ss/sd-scripts/pull/1821) Thanks to sdbds! + + +- Fixed an issue where the saved model would be corrupted (pos_embed would not be saved) when `--enable_scaled_pos_embed` was specified in `sd3_train.py`. + +Dec 3, 2024: + +-`--blocks_to_swap` now works in FLUX.1 ControlNet training. Sample commands for 24GB VRAM and 16GB VRAM are added [here](#flux1-controlnet-training). + +Dec 2, 2024: + +- FLUX.1 ControlNet training is supported. PR [#1813](https://github.com/kohya-ss/sd-scripts/pull/1813). Thanks to minux302! See PR and [here](#flux1-controlnet-training) for details. + - Not fully tested. Feedback is welcome. + - 80GB VRAM is required for 1024x1024 resolution, and 48GB VRAM is required for 512x512 resolution. + - Currently, it only works in Linux environment (or Windows WSL2) because DeepSpeed is required. + - Multi-GPU training is not tested. + +Dec 1, 2024: + +- Pseudo Huber loss is now available for FLUX.1 and SD3.5 training. See PR [#1808](https://github.com/kohya-ss/sd-scripts/pull/1808) for details. Thanks to recris! + - Specify `--loss_type huber` or `--loss_type smooth_l1` to use it. `--huber_c` and `--huber_scale` are also available. + +- [Prodigy + ScheduleFree](https://github.com/LoganBooker/prodigy-plus-schedule-free) is supported. See PR [#1811](https://github.com/kohya-ss/sd-scripts/pull/1811) for details. Thanks to rockerBOO! + +Nov 14, 2024: + +- Improved the implementation of block swap and made it available for both FLUX.1 and SD3 LoRA training. See [FLUX.1 LoRA training](#flux1-lora-training) etc. for how to use the new options. Training is possible with about 8-10GB of VRAM. +- During fine-tuning, the memory usage when specifying the same number of blocks has increased slightly, but the training speed when specifying block swap has been significantly improved. +- There may be bugs due to the significant changes. Feedback is welcome. + +## FLUX.1 training + +- [FLUX.1 LoRA training](#flux1-lora-training) + - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training) + - [Distribution of timesteps](#distribution-of-timesteps) + - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training) + - [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) + - [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) +- [FLUX.1 ControlNet training](#flux1-controlnet-training) +- [FLUX.1 OFT training](#flux1-oft-training) +- [Inference for FLUX.1 with LoRA model](#inference-for-flux1-with-lora-model) +- [FLUX.1 fine-tuning](#flux1-fine-tuning) + - [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning) +- [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models) +- [Convert FLUX LoRA](#convert-flux-lora) +- [Merge LoRA to FLUX.1 checkpoint](#merge-lora-to-flux1-checkpoint) +- [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) +- [Convert Diffusers to FLUX.1](#convert-diffusers-to-flux1) + +### FLUX.1 LoRA training + +We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. + +FLUX.1 model, CLIP-L, and T5XXL models are recommended to be in bf16/fp16 format. If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with `float8_e4m3fn` format. + +Sample command is below. It will work with 24GB VRAM GPUs. + +``` +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py +--pretrained_model_name_or_path flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors +--ae ae.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers +--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 +--network_module networks.lora_flux --network_dim 4 --network_train_unet_only +--optimizer_type adamw8bit --learning_rate 1e-4 +--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base +--highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml +--output_dir path/to/output/dir --output_name flux-lora-name +--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 +``` +(The command is multi-line for readability. Please combine it into one line.) + +We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. + +The trained LoRA model can be used with ComfyUI. + +When training LoRA for Text Encoder (without `--network_train_unet_only`), more VRAM is required. Please refer to the settings below to reduce VRAM usage. + +__Options for GPUs with less VRAM:__ + +By specifying `--blocks_to_swap`, you can save VRAM by swapping some blocks between CPU and GPU. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. + +Specify a number like `--blocks_to_swap 10`. A larger number will swap more blocks, saving more VRAM, but training will be slower. In FLUX.1, you can swap up to 35 blocks. + +`--cpu_offload_checkpointing` offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--blocks_to_swap`. + +Adafactor optimizer may reduce the VRAM usage than 8bit AdamW. Please use settings like below: + +``` +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 +``` + +The training can be done with 16GB VRAM GPUs with the batch size of 1. Please change your dataset configuration. + +The training can be done with 12GB VRAM GPUs with `--blocks_to_swap 16` with 8bit AdamW. Please use settings like below: + +``` +--blocks_to_swap 16 +``` + +For GPUs with less than 10GB of VRAM, it is recommended to use an fp8 checkpoint for T5XXL. You can download `t5xxl_fp8_e4m3fn.safetensors` from [comfyanonymous/flux_text_encoders](https://huggingface.co/comfyanonymous/flux_text_encoders) (please use without `scaled`). + +10GB VRAM GPUs will work with 22 blocks swapped, and 8GB VRAM GPUs will work with 28 blocks swapped. + +__`--split_mode` is deprecated. This option is still available, but they will be removed in the future. Please use `--blocks_to_swap` instead. If this option is specified and `--blocks_to_swap` is not specified, `--blocks_to_swap 18` is automatically enabled.__ + +#### Key Options for FLUX.1 LoRA training + +There are many unknown points in FLUX.1 training, so some settings can be specified by arguments. Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome. + +- `--pretrained_model_name_or_path` is the path to the pretrained model (FLUX.1). bf16 (original BFL model) is recommended (`flux1-dev.safetensors` or `flux1-dev.sft`). If you specify `--fp8_base`, you can use fp8 models for FLUX.1. The fp8 model is only compatible with `float8_e4m3fn` format. +- `--clip_l` is the path to the CLIP-L model. +- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (float8_e4m3fn) models for T5XXL. However, it is recommended to use fp16 models for caching. +- `--ae` is the path to the autoencoder model (`ae.safetensors` or `ae.sft`). + +- `--timestep_sampling` is the method to sample timesteps (0-1): + - `sigma`: sigma-based, same as SD3 + - `uniform`: uniform random + - `sigmoid`: sigmoid of random normal, same as x-flux, AI-toolkit etc. + - `shift`: shifts the value of sigmoid of normal distribution random number + - `flux_shift`: shifts the value of sigmoid of normal distribution random number, depending on the resolution (same as FLUX.1 dev inference). `--discrete_flow_shift` is ignored when `flux_shift` is specified. +- `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform. + - This option is effective even when`--timestep_sampling shift` is specified. + - Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. +- `--model_prediction_type` is how to interpret and process the model prediction: + - `raw`: use as is, same as x-flux + - `additive`: add to noisy input + - `sigma_scaled`: apply sigma scaling, same as SD3 +- `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3). +- `--blocks_to_swap`. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. + +The existing `--loss_type` option may be useful for FLUX.1 training. The default is `l2`. + +~~In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted.~~ + +In our experiments, `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type) seems to work better. + +The settings in [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit) seems to be equivalent to `--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0` (with the default `l2` loss_type). + +Other settings may work better, so please try different settings. + +Other options are described below. + +#### Distribution of timesteps + +`--timestep_sampling` and `--sigmoid_scale`, `--discrete_flow_shift` adjust the distribution of timesteps. The distribution is shown in the figures below. + +The effect of `--discrete_flow_shift` with `--timestep_sampling shift` (when `--sigmoid_scale` is not specified, the default is 1.0): +![Figure_2](https://github.com/user-attachments/assets/d9de42f9-f17d-40da-b88d-d964402569c6) + +The difference between `--timestep_sampling sigmoid` and `--timestep_sampling uniform` (when `--timestep_sampling sigmoid` or `uniform` is specified, `--discrete_flow_shift` is ignored): +![Figure_3](https://github.com/user-attachments/assets/27029009-1f5d-4dc0-bb24-13d02ac4fdad) + +The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--timestep_sampling sigmoid` is specified, `--discrete_flow_shift` is ignored): +![Figure_4](https://github.com/user-attachments/assets/08a2267c-e47e-48b7-826e-f9a080787cdc) + +#### Key Features for FLUX.1 LoRA training + +1. CLIP-L and T5XXL LoRA Support: + - FLUX.1 LoRA training now supports CLIP-L and T5XXL LoRA training. + - Remove `--network_train_unet_only` from your command. + - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time. + - T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. + - The learning rates for CLIP-L and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5`. The first value is the learning rate for CLIP-L, and the second value is for T5XXL. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL. + - The trained LoRA can be used with ComfyUI. + - Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet. + + | trained LoRA|option|network_args|cache_text_encoder_outputs (*1)| + |---|---|---|---| + |FLUX.1|`--network_train_unet_only`|-|o| + |FLUX.1 + CLIP-L|-|-|o (*2)| + |FLUX.1 + CLIP-L + T5XXL|-|`train_t5xxl=True`|-| + |CLIP-L (*3)|`--network_train_text_encoder_only`|-|o (*2)| + |CLIP-L + T5XXL (*3)|`--network_train_text_encoder_only`|`train_t5xxl=True`|-| + + - *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. + - *2: T5XXL output can be cached for CLIP-L LoRA training. + - *3: Not tested yet. + +2. Experimental FP8/FP16 mixed training: + - `--fp8_base_unet` enables training with fp8 for FLUX and bf16/fp16 for CLIP-L/T5XXL. + - FLUX can be trained with fp8, and CLIP-L/T5XXL can be trained with bf16/fp16. + - When specifying this option, the `--fp8_base` option is automatically enabled. + +3. Split Q/K/V Projection Layers (Experimental): + - Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them. + - Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available). + - May increase expressiveness but also training time. + - The trained model is compatible with normal LoRA models in sd-scripts and can be used in environments like ComfyUI. + - Converting to AI-toolkit (Diffusers) format with `convert_flux_lora.py` will reduce the size. + +4. T5 Attention Mask Application: + - T5 attention mask is applied when `--apply_t5_attn_mask` is specified. + - Now applies mask when encoding T5 and in the attention of Double and Single Blocks + - Affects fine-tuning, LoRA training, and inference in `flux_minimal_inference.py`. + +5. Multi-resolution Training Support: + - FLUX.1 now supports multi-resolution training, even with caching latents to disk. + + +Technical details of Q/K/V split: + +In the implementation of Black Forest Labs' model, the projection layers of q/k/v (and txt in single blocks) are concatenated into one. If LoRA is added there as it is, the LoRA module is only one, and the dimension is large. In contrast, in the implementation of Diffusers, the projection layers of q/k/v/txt are separated. Therefore, the LoRA module is applied to q/k/v/txt separately, and the dimension is smaller. This option is for training LoRA similar to the latter. + +The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, since there are zero weights in some parts, the model size will be large. + +#### Specify rank for each layer in FLUX.1 + +You can specify the rank for each layer in FLUX.1 by specifying the following network_args. If you specify `0`, LoRA will not be applied to that layer. + +When network_args is not specified, the default value (`network_dim`) is applied, same as before. + +|network_args|target layer| +|---|---| +|img_attn_dim|img_attn in DoubleStreamBlock| +|txt_attn_dim|txt_attn in DoubleStreamBlock| +|img_mlp_dim|img_mlp in DoubleStreamBlock| +|txt_mlp_dim|txt_mlp in DoubleStreamBlock| +|img_mod_dim|img_mod in DoubleStreamBlock| +|txt_mod_dim|txt_mod in DoubleStreamBlock| +|single_dim|linear1 and linear2 in SingleStreamBlock| +|single_mod_dim|modulation in SingleStreamBlock| + +`"verbose=True"` is also available for debugging. It shows the rank of each layer. + +example: +``` +--network_args "img_attn_dim=4" "img_mlp_dim=8" "txt_attn_dim=2" "txt_mlp_dim=2" +"img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2" "verbose=True" +``` + +You can apply LoRA to the conditioning layers of Flux by specifying `in_dims` in network_args. When specifying, be sure to specify 5 numbers in `[]` as a comma-separated list. + +example: +``` +--network_args "in_dims=[4,2,2,2,4]" +``` + +Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, `txt_in`. The above example applies LoRA to all conditioning layers, with rank 4 for `img_in`, 2 for `time_in`, `vector_in`, `guidance_in`, and 4 for `txt_in`. + +If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,0,4]` applies LoRA only to `img_in` and `txt_in`. + +#### Specify blocks to train in FLUX.1 LoRA training + +You can specify the blocks to train in FLUX.1 LoRA training by specifying `train_double_block_indices` and `train_single_block_indices` in network_args. The indices are 0-based. The default (when omitted) is to train all blocks. The indices are specified as a list of integers or a range of integers, like `0,1,5,8` or `0,1,4-5,7`. The number of double blocks is 19, and the number of single blocks is 38, so the valid range is 0-18 and 0-37, respectively. `all` is also available to train all blocks, `none` is also available to train no blocks. + +example: +``` +--network_args "train_double_block_indices=0,1,8-12,18" "train_single_block_indices=3,10,20-25,37" +``` + +``` +--network_args "train_double_block_indices=none" "train_single_block_indices=10-15" +``` + +If you specify one of `train_double_block_indices` or `train_single_block_indices`, the other will be trained as usual. + +### FLUX.1 ControlNet training +We have added a new training script for ControlNet training. The script is flux_train_control_net.py. See --help for options. + +Sample command is below. It will work with 80GB VRAM GPUs. +``` +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_control_net.py +--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors +--ae ae.safetensors --save_model_as safetensors --sdpa --persistent_data_loader_workers +--max_data_loader_n_workers 1 --seed 42 --gradient_checkpointing --mixed_precision bf16 +--optimizer_type adamw8bit --learning_rate 2e-5 +--highvram --max_train_epochs 1 --save_every_n_steps 1000 --dataset_config dataset.toml +--output_dir /path/to/output/dir --output_name flux-cn +--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 --deepspeed +``` + +For 24GB VRAM GPUs, you can train with 16 blocks swapped and caching latents and text encoder outputs with the batch size of 1. Remove `--deepspeed` . Sample command is below. Not fully tested. +``` + --blocks_to_swap 16 --cache_latents_to_disk --cache_text_encoder_outputs_to_disk +``` + +The training can be done with 16GB VRAM GPUs with around 30 blocks swapped. + +`--gradient_accumulation_steps` is also available. The default value is 1 (no accumulation), but according to the original PR, 8 is used. + +### FLUX.1 OFT training + +You can train OFT with almost the same options as LoRA, such as `--timestamp_sampling`. The following points are different. + +- Change `--network_module` from `networks.lora_flux` to `networks.oft_flux`. +- `--network_dim` is the number of OFT blocks. Unlike LoRA rank, the smaller the dim, the larger the model. We recommend about 64 or 128. Please make the output dimension of the target layer of OFT divisible by the value of `--network_dim` (an error will occur if it is not divisible). Valid values are 64, 128, 256, 512, 1024, etc. +- `--network_alpha` is treated as a constraint for OFT. We recommend about 1e-2 to 1e-4. The default value when omitted is 1, which is too large, so be sure to specify it. +- CLIP/T5XXL is not supported. Specify `--network_train_unet_only`. +- `--network_args` specifies the hyperparameters of OFT. The following are valid: + - Specify `enable_all_linear=True` to target all linear connections in the MLP layer. The default is False, which targets only attention. + +Currently, there is no environment to infer FLUX.1 OFT. Inference is only possible with `flux_minimal_inference.py` (specify OFT model with `--lora`). + +Sample command is below. It will work with 24GB VRAM GPUs with the batch size of 1. + +``` +--network_module networks.oft_flux --network_dim 128 --network_alpha 1e-3 +--network_args "enable_all_linear=True" --learning_rate 1e-5 +``` + +The training can be done with 16GB VRAM GPUs without `--enable_all_linear` option and with Adafactor optimizer. + +### Inference for FLUX.1 with LoRA model + +The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. + +``` +python flux_minimal_inference.py --ckpt flux1-dev.safetensors --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.safetensors --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 +``` + +### FLUX.1 fine-tuning + +The memory-efficient training with block swap is based on 2kpr's implementation. Thanks to 2kpr! + +__`--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. These options is still available, but they will be removed in the future. Please use `--blocks_to_swap` instead. These options are equivalent to specifying `double_blocks_to_swap + single_blocks_to_swap // 2` in `--blocks_to_swap`.__ + +Sample command for FLUX.1 fine-tuning is below. This will work with 24GB VRAM GPUs, and 64GB main memory is recommended. + +``` +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train.py +--pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae_dev.safetensors +--save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 +--seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 +--dataset_config dataset_1024_bs1.toml --output_dir path/to/output/dir --output_name output-name +--learning_rate 5e-5 --max_train_epochs 4 --sdpa --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1 +--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" +--lr_scheduler constant_with_warmup --max_grad_norm 0.0 +--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 +--fused_backward_pass --blocks_to_swap 8 --full_bf16 +``` +(The command is multi-line for readability. Please combine it into one line.) + +Options are almost the same as LoRA training. The difference is `--full_bf16`, `--fused_backward_pass` and `--blocks_to_swap`. `--cpu_offload_checkpointing` is also available. + +`--full_bf16` enables the training with bf16 (weights and gradients). + +`--fused_backward_pass` enables the fusing of the optimizer step into the backward pass for each parameter. This reduces the memory usage during training. Only Adafactor optimizer is supported for now. Stochastic rounding is also enabled when `--fused_backward_pass` and `--full_bf16` are specified. + +`--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency and stochastic rounding. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now. + +`--blocks_to_swap` is the number of blocks to swap. The default is None (no swap). The maximum value is 35. + +`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage. This option cannot be used with `--blocks_to_swap`. + +All these options are experimental and may change in the future. + +The increasing the number of blocks to swap may reduce the memory usage, but the training speed will be slower. `--cpu_offload_checkpointing` also slows down the training. + +Swap 8 blocks without cpu offload checkpointing may be a good starting point for 24GB VRAM GPUs. Please try different settings according to VRAM usage and training speed. + +The learning rate and the number of epochs are not optimized yet. Please adjust them according to the training results. + +#### How to use block swap + +There are two possible ways to use block swap. It is unknown which is better. + +1. Swap the minimum number of blocks that fit in VRAM with batch size 1 and shorten the training speed of one step. + + The above command example is for this usage. + +2. Swap many blocks to increase the batch size and shorten the training speed per data. + + For example, swapping 35 blocks seems to increase the batch size to about 5. In this case, the training speed per data will be relatively faster than 1. + +#### Training with <24GB VRAM GPUs + +Swap 28 blocks without cpu offload checkpointing may be working with 12GB VRAM GPUs. Please try different settings according to VRAM size of your GPU. + +T5XXL requires about 10GB of VRAM, so 10GB of VRAM will be minimum requirement for FLUX.1 fine-tuning. + +#### Key Features for FLUX.1 fine-tuning + +1. Technical details of block swap: + - Reduce memory usage by transferring double and single blocks of FLUX.1 from GPU to CPU when they are not needed. + - During forward pass, the weights of the blocks that have finished calculation are transferred to CPU, and the weights of the blocks to be calculated are transferred to GPU. + - The same is true for the backward pass, but the order is reversed. The gradients remain on the GPU. + - Since the transfer between CPU and GPU takes time, the training will be slower. + - `--blocks_to_swap` specify the number of blocks to swap. + - About 640MB of memory can be saved per block. + - (Update 1: Nov 12, 2024) + - The maximum number of blocks that can be swapped is 35. + - We are exchanging only the data of the weights (weight.data) in reference to the implementation of OneTrainer (thanks to OneTrainer). However, the mechanism of the exchange is a custom implementation. + - Since it takes time to free CUDA memory (torch.cuda.empty_cache()), we reuse the CUDA memory allocated to weight.data as it is and exchange the weights between modules. + - This shortens the time it takes to exchange weights between modules. + - Since the weights must be almost identical to be exchanged, FLUX.1 exchanges the weights between double blocks and single blocks. + - In SD3, all blocks are similar, but some weights are different, so there are weights that always remain on the GPU. + +2. Sample Image Generation: + - Sample image generation during training is now supported. + - The prompts are cached and used for generation if `--cache_latents` is specified. So changing the prompts during training will not affect the generated images. + - Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. + - Note: It will be very slow when `--blocks_to_swap` is specified. + +3. Experimental Memory-Efficient Saving: + - `--mem_eff_save` option can further reduce memory consumption during model saving (about 22GB). + - This is a custom implementation and may cause unexpected issues. Use with caution. + +4. T5XXL Token Length Control: + - Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. + - Default is 512 in dev and 256 in schnell models. + +5. Multi-GPU Training Support: + - Note: `--double_blocks_to_swap` and `--single_blocks_to_swap` cannot be used in multi-GPU training. + +6. Disable mmap Load for Safetensors: + - `--disable_mmap_load_safetensors` option now works in `flux_train.py`. + - Speeds up model loading during training in WSL2. + - Effective in reducing memory usage when loading models during multi-GPU training. + + +### Extract LoRA from FLUX.1 Models + +Script: `networks/flux_extract_lora.py` + +Extracts LoRA from the difference between two FLUX.1 models. + +Offers memory-efficient option with `--mem_eff_safe_open`. + +CLIP-L LoRA is not supported. + +### Convert FLUX LoRA + +Script: `convert_flux_lora.py` + +Converts LoRA between sd-scripts format (BFL-based) and AI-toolkit format (Diffusers-based). + +If you use LoRA in the inference environment, converting it to AI-toolkit format may reduce temporary memory usage. + +Note that re-conversion will increase the size of LoRA. + +CLIP-L/T5XXL LoRA is not supported. + +### Merge LoRA to FLUX.1 checkpoint + +`networks/flux_merge_lora.py` merges LoRA to FLUX.1 checkpoint, CLIP-L or T5XXL models. __The script is experimental.__ + +``` +python networks/flux_merge_lora.py --flux_model flux1-dev.safetensors --save_to output.safetensors --models lora1.safetensors --ratios 2.0 --save_precision fp16 --loading_device cuda --working_device cpu +``` + +You can also merge multiple LoRA models into a FLUX.1 model. Specify multiple LoRA models in `--models`. Specify the same number of ratios in `--ratios`. + +CLIP-L and T5XXL LoRA are supported. `--clip_l` and `--clip_l_save_to` are for CLIP-L, `--t5xxl` and `--t5xxl_save_to` are for T5XXL. Sample command is below. + +``` +--clip_l clip_l.safetensors --clip_l_save_to merged_clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --t5xxl_save_to merged_t5xxl.safetensors +``` + +FLUX.1, CLIP-L, and T5XXL can be merged together or separately for memory efficiency. + +An experimental option `--mem_eff_load_save` is available. This option is for memory-efficient loading and saving. It may also speed up loading and saving. + +`--loading_device` is the device to load the LoRA models. `--working_device` is the device to merge (calculate) the models. Default is `cpu` for both. Loading / working device examples are below (in the case of `--save_precision fp16` or `--save_precision bf16`, `float32` will consume more memory): + +- 'cpu' / 'cpu': Uses >50GB of RAM, but works on any machine. +- 'cuda' / 'cpu': Uses 24GB of VRAM, but requires 30GB of RAM. +- 'cpu' / 'cuda': Uses 4GB of VRAM, but requires 50GB of RAM, faster than 'cpu' / 'cpu' or 'cuda' / 'cpu'. +- 'cuda' / 'cuda': Uses 30GB of VRAM, but requires 30GB of RAM, faster than 'cpu' / 'cpu' or 'cuda' / 'cpu'. + +`--save_precision` is the precision to save the merged model. In the case of LoRA models are trained with `bf16`, we are not sure which is better, `fp16` or `bf16` for `--save_precision`. + +The script can merge multiple LoRA models. If you want to merge multiple LoRA models, specify `--concat` option to work the merged LoRA model properly. + +### FLUX.1 Multi-resolution training + +You can define multiple resolutions in the dataset configuration file. + +The dataset configuration file is like below. You can define multiple resolutions with different batch sizes. The resolutions are defined in the `[[datasets]]` section. The `[[datasets.subsets]]` section is for the dataset directory. Please specify the same directory for each resolution. + +``` +[general] +# define common settings here +flip_aug = true +color_aug = false +keep_tokens_separator= "|||" +shuffle_caption = false +caption_tag_dropout_rate = 0 +caption_extension = ".txt" + +[[datasets]] +# define the first resolution here +batch_size = 2 +enable_bucket = true +resolution = [1024, 1024] + + [[datasets.subsets]] + image_dir = "path/to/image/dir" + num_repeats = 1 + +[[datasets]] +# define the second resolution here +batch_size = 3 +enable_bucket = true +resolution = [768, 768] + + [[datasets.subsets]] + image_dir = "path/to/image/dir" + num_repeats = 1 + +[[datasets]] +# define the third resolution here +batch_size = 4 +enable_bucket = true +resolution = [512, 512] + + [[datasets.subsets]] + image_dir = "path/to/image/dir" + num_repeats = 1 +``` + +### Convert Diffusers to FLUX.1 + +Script: `convert_diffusers_to_flux1.py` + +Converts Diffusers models to FLUX.1 models. The script is experimental. See `--help` for options. schnell and dev models are supported. AE/CLIP/T5XXL are not supported. The diffusers folder is a parent folder of `rmer` folder. + +``` +python tools/convert_diffusers_to_flux.py --diffusers_path path/to/diffusers_folder_or_00001_safetensors --save_to path/to/flux1.safetensors --mem_eff_load_save --save_precision bf16 +``` + +## SD3 training + +SD3.5L/M training is now available. + +### SD3 LoRA training + +The script is `sd3_train_network.py`. See `--help` for options. + +SD3 model, CLIP-L, CLIP-G, and T5XXL models are recommended to be in float/fp16 format. If you specify `--fp8_base`, you can use fp8 models for SD3. The fp8 model is only compatible with `float8_e4m3fn` format. + +Sample command is below. It will work with 16GB VRAM GPUs (SD3.5L). + +``` +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 sd3_train_network.py +--pretrained_model_name_or_path path/to/sd3.5_large.safetensors --clip_l sd3/clip_l.safetensors --clip_g sd3/clip_g.safetensors --t5xxl sd3/t5xxl_fp16.safetensors +--cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers +--max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 +--network_module networks.lora_sd3 --network_dim 4 --network_train_unet_only +--optimizer_type adamw8bit --learning_rate 1e-4 +--cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base +--highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml +--output_dir path/to/output/dir --output_name sd3-lora-name +``` +(The command is multi-line for readability. Please combine it into one line.) + +Like FLUX.1 training, the `--blocks_to_swap` option for memory reduction is available. The maximum number of blocks that can be swapped is 36 for SD3.5L and 22 for SD3.5M. + +Adafactor optimizer is also available. + +`--cpu_offload_checkpointing` option is not available. + +We also not sure how many epochs are needed for convergence, and how the learning rate should be adjusted. + +The trained LoRA model can be used with ComfyUI. + +#### Key Options for SD3 LoRA training + +Here are the arguments. The arguments and sample settings are still experimental and may change in the future. Feedback on the settings is welcome. + +- `--network_module` is the module for LoRA training. Specify `networks.lora_sd3` for SD3 LoRA training. +- `--pretrained_model_name_or_path` is the path to the pretrained model (SD3/3.5). If you specify `--fp8_base`, you can use fp8 models for SD3/3.5. The fp8 model is only compatible with `float8_e4m3fn` format. +- `--clip_l` is the path to the CLIP-L model. +- `--clip_g` is the path to the CLIP-G model. +- `--t5xxl` is the path to the T5XXL model. If you specify `--fp8_base`, you can use fp8 (float8_e4m3fn) models for T5XXL. However, it is recommended to use fp16 models for caching. +- `--vae` is the path to the autoencoder model. __This option is not necessary for SD3.__ VAE is included in the standard SD3 model. +- `--disable_mmap_load_safetensors` is to disable memory mapping when loading safetensors. __This option significantly reduces the memory usage when loading models for Windows users.__ +- `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are the dropout rates for the embeddings of CLIP-L, CLIP-G, and T5XXL, described in [SAI research papre](http://arxiv.org/pdf/2403.03206). The default is 0.0. For LoRA training, it is seems to be better to set 0.0. +- `--pos_emb_random_crop_rate` is the rate of random cropping of positional embeddings, described in [SD3.5M model card](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium). The default is 0. It is seems to be better to set 0.0 for LoRA training. +- `--enable_scaled_pos_embed` is to enable the scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M. Details are described below. +- `--training_shift` is the shift value for the training distribution of timesteps. The default is 1.0 (uniform distribution, no shift). If less than 1.0, the side closer to the image is more sampled, and if more than 1.0, the side closer to noise is more sampled. + +Other options are described below. + +#### Key Features for SD3 LoRA training + +1. CLIP-L, G and T5XXL LoRA Support: + - SD3 LoRA training now supports CLIP-L, CLIP-G and T5XXL LoRA training. + - Remove `--network_train_unet_only` from your command. + - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L and G is also trained at the same time. + - T5XXL output can be cached for CLIP-L and G LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. + - The learning rates for CLIP-L, CLIP-G and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5 5e-6`. The first value is the learning rate for CLIP-L, the second value is for CLIP-G, and the third value is for T5XXL. If you specify only one, the learning rates for CLIP-L, CLIP-G and T5XXL will be the same. If the third value is not specified, the second value is used for T5XXL. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL. + - The trained LoRA can be used with ComfyUI. + + | trained LoRA|option|network_args|cache_text_encoder_outputs (*1)| + |---|---|---|---| + |MMDiT|`--network_train_unet_only`|-|o| + |MMDiT + CLIP-L + CLIP-G|-|-|o (*2)| + |MMDiT + CLIP-L + CLIP-G + T5XXL|-|`train_t5xxl=True`|-| + |CLIP-L + CLIP-G (*3)|`--network_train_text_encoder_only`|-|o (*2)| + |CLIP-L + CLIP-G + T5XXL (*3)|`--network_train_text_encoder_only`|`train_t5xxl=True`|-| + + - *1: `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. + - *2: T5XXL output can be cached for CLIP-L and G LoRA training. + - *3: Not tested yet. + +2. Experimental FP8/FP16 mixed training: + - `--fp8_base_unet` enables training with fp8 for MMDiT and bf16/fp16 for CLIP-L/G/T5XXL. + - When specifying this option, the `--fp8_base` option is automatically enabled. + +3. Split Q/K/V Projection Layers (Experimental): + - Same as FLUX.1. + +4. CLIP-L/G and T5 Attention Mask Application: + - This function is planned to be implemented in the future. + +5. Multi-resolution Training Support: + - Only for SD3.5M. + - Same as FLUX.1 for data preparation. + - If you train with multiple resolutions, you can enable the scaled positional embeddings with `--enable_scaled_pos_embed`. The default is False. __This option is an experimental feature.__ + +6. Weighting scheme and training shift: + - The weighting scheme is described in the section 3.1 of the [SD3 paper](https://arxiv.org/abs/2403.03206v1). + - The uniform distribution is the default. If you want to change the distribution, see `--help` for options. + - `--training_shift` is the shift value for the training distribution of timesteps. + - The effect of a shift in uniform distribution is shown in the figure below. + - ![Figure_1](https://github.com/user-attachments/assets/99a72c67-adfb-4440-81d4-a718985ff350) + +Technical details of multi-resolution training for SD3.5M: + +SD3.5M does not use scaled positional embeddings for multi-resolution training, and is trained with a single positional embedding. Therefore, this feature is very experimental. + +Generally, in multi-resolution training, the values of the positional embeddings must be the same for each resolution. That is, the same value must be in the same position for 512x512, 768x768, and 1024x1024. To achieve this, the positional embeddings for each resolution are calculated in advance and switched according to the resolution of the training data. This feature is enabled by `--enable_scaled_pos_embed`. + +This idea and the code for calculating scaled positional embeddings are contributed by KohakuBlueleaf. Thanks to KohakuBlueleaf! + + +#### Specify rank for each layer in SD3 LoRA + +You can specify the rank for each layer in SD3 by specifying the following network_args. If you specify `0`, LoRA will not be applied to that layer. + +When network_args is not specified, the default value (`network_dim`) is applied, same as before. + +|network_args|target layer| +|---|---| +|context_attn_dim|attn in context_block| +|context_mlp_dim|mlp in context_block| +|context_mod_dim|adaLN_modulation in context_block| +|x_attn_dim|attn in x_block| +|x_mlp_dim|mlp in x_block| +|x_mod_dim|adaLN_modulation in x_block| + +`"verbose=True"` is also available for debugging. It shows the rank of each layer. + +example: +``` +--network_args "context_attn_dim=2" "context_mlp_dim=3" "context_mod_dim=4" "x_attn_dim=5" "x_mlp_dim=6" "x_mod_dim=7" "verbose=True" +``` + +You can apply LoRA to the conditioning layers of SD3 by specifying `emb_dims` in network_args. When specifying, be sure to specify 6 numbers in `[]` as a comma-separated list. + +example: +``` +--network_args "emb_dims=[2,3,4,5,6,7]" +``` + +Each number corresponds to `context_embedder`, `t_embedder`, `x_embedder`, `y_embedder`, `final_layer_adaLN_modulation`, `final_layer_linear`. The above example applies LoRA to all conditioning layers, with rank 2 for `context_embedder`, 3 for `t_embedder`, 4 for `context_embedder`, 5 for `y_embedder`, 6 for `final_layer_adaLN_modulation`, and 7 for `final_layer_linear`. + +If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,4,0,0]` applies LoRA only to `context_embedder` and `y_embedder`. + +#### Specify blocks to train in SD3 LoRA training + +You can specify the blocks to train in SD3 LoRA training by specifying `train_block_indices` in network_args. The indices are 0-based. The default (when omitted) is to train all blocks. The indices are specified as a list of integers or a range of integers, like `0,1,5,8` or `0,1,4-5,7`. + +The number of blocks depends on the model. The valid range is 0-(the number of blocks - 1). `all` is also available to train all blocks, `none` is also available to train no blocks. + +example: +``` +--network_args "train_block_indices=1,2,6-8" +``` + +### Inference for SD3 with LoRA model + +The inference script is also available. The script is `sd3_minimal_inference.py`. See `--help` for options. + +### SD3 fine-tuning + +Documentation is not available yet. Please refer to the FLUX.1 fine-tuning guide for now. The major difference are following: + +- `--clip_g` is also available for SD3 fine-tuning. +- `--timestep_sampling` `--discrete_flow_shift``--model_prediction_type` --guidance_scale` are not necessary for SD3 fine-tuning. +- Use `--vae` instead of `--ae` if necessary. __This option is not necessary for SD3.__ VAE is included in the standard SD3 model. +- `--disable_mmap_load_safetensors` is available. __This option significantly reduces the memory usage when loading models for Windows users.__ +- `--cpu_offload_checkpointing` is not available for SD3 fine-tuning. +- `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are available same as LoRA training. +- `--pos_emb_random_crop_rate` and `--enable_scaled_pos_embed` are available for SD3.5M fine-tuning. +- Training text encoders is available with `--train_text_encoder` option, similar to SDXL training. + - CLIP-L and G can be trained with `--train_text_encoder` option. Training T5XXL needs `--train_t5xxl` option. + - If you use the cached text encoder outputs for T5XXL with training CLIP-L and G, specify `--use_t5xxl_cache_only`. This option enables to use the cached text encoder outputs for T5XXL only. + - The learning rates for CLIP-L, CLIP-G and T5XXL can be specified separately. `--text_encoder_lr1`, `--text_encoder_lr2` and `--text_encoder_lr3` are available. + +### Extract LoRA from SD3 Models + +Not available yet. + +### Convert SD3 LoRA + +Not available yet. + +### Merge LoRA to SD3 checkpoint + +Not available yet. + +--- + [__Change History__](#change-history) is moved to the bottom of the page. 更新履歴は[ページ末尾](#change-history)に移しました。 [日本語版READMEはこちら](./README-ja.md) +The development version is in the `dev` branch. Please check the dev branch for the latest changes. + +FLUX.1 and SD3/SD3.5 support is done in the `sd3` branch. If you want to train them, please use the sd3 branch. + + For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais! This repository contains the scripts for: @@ -137,6 +871,221 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### Working in progress + +- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries. + - bitsandbytes, transformers, accelerate and huggingface_hub are updated. + - If you encounter any issues, please report them. + +- Fixed a bug where the loss weight was incorrect when `--debiased_estimation_loss` was specified with `--v_parameterization`. PR [#1715](https://github.com/kohya-ss/sd-scripts/pull/1715) Thanks to catboxanon! See [the PR](https://github.com/kohya-ss/sd-scripts/pull/1715) for details. + - Removed the warning when `--v_parameterization` is specified in SDXL and SD1.5. PR [#1717](https://github.com/kohya-ss/sd-scripts/pull/1717) + +- There was a bug where the min_bucket_reso/max_bucket_reso in the dataset configuration did not create the correct resolution bucket if it was not divisible by bucket_reso_steps. These values are now warned and automatically rounded to a divisible value. Thanks to Maru-mee for raising the issue. Related PR [#1632](https://github.com/kohya-ss/sd-scripts/pull/1632) + +- `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds! + - There is no abbreviation, so please specify the full path like `--optimizer_type bitsandbytes.optim.AdEMAMix8bit` (not bnb but bitsandbytes). + +- Fixed a bug in the cache of latents. When `flip_aug`, `alpha_mask`, and `random_crop` are different in multiple subsets in the dataset configuration file (.toml), the last subset is used instead of reflecting them correctly. + +- Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris! + +- Improvements in OFT (Orthogonal Finetuning) Implementation + 1. Optimization of Calculation Order: + - Changed the calculation order in the forward method from (Wx)R to W(xR). + - This has improved computational efficiency and processing speed. + 2. Correction of Bias Application: + - In the previous implementation, R was incorrectly applied to the bias. + - The new implementation now correctly handles bias by using F.conv2d and F.linear. + 3. Efficiency Enhancement in Matrix Operations: + - Introduced einsum in both the forward and merge_to methods. + - This has optimized matrix operations, resulting in further speed improvements. + 4. Proper Handling of Data Types: + - Improved to use torch.float32 during calculations and convert results back to the original data type. + - This maintains precision while ensuring compatibility with the original model. + 5. Unified Processing for Conv2d and Linear Layers: + - Implemented a consistent method for applying OFT to both layer types. + - These changes have made the OFT implementation more efficient and accurate, potentially leading to improved model performance and training stability. + + - Additional Information + * Recommended α value for OFT constraint: We recommend using α values between 1e-4 and 1e-2. This differs slightly from the original implementation of "(α\*out_dim\*out_dim)". Our implementation uses "(α\*out_dim)", hence we recommend higher values than the 1e-5 suggested in the original implementation. + + * Performance Improvement: Training speed has been improved by approximately 30%. + + * Inference Environment: This implementation is compatible with and operates within Stable Diffusion web UI (SD1/2 and SDXL). + +- The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds! + - See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler. + - `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc. + +https://github.com/kohya-ss/sd-scripts/pull/1393 +- When enlarging images in the script (when the size of the training image is small and bucket_no_upscale is not specified), it has been changed to use Pillow's resize and LANCZOS interpolation instead of OpenCV2's resize and Lanczos4 interpolation. The quality of the image enlargement may be slightly improved. PR [#1426](https://github.com/kohya-ss/sd-scripts/pull/1426) Thanks to sdbds! + +- Sample image generation during training now works on non-CUDA devices. PR [#1433](https://github.com/kohya-ss/sd-scripts/pull/1433) Thanks to millie-v! + +- `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened! + +- Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! + - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. + - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only Adafactor is supported. Gradient accumulation is not available. + - Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`. + - Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size. + - PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`. + - Mechanism: Normally, backward -> step is performed for each parameter, so all gradients need to be temporarily stored in memory. "Fuse backward and step" reduces memory usage by performing backward/step for each parameter and reflecting the gradient immediately. The more parameters there are, the greater the effect, so it is not effective in other training scripts (LoRA, etc.) where the memory usage peak is elsewhere, and there are no plans to implement it in those training scripts. + +- Optimizer groups feature is added to SDXL training. PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319) + - Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer. + - Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect is limited to a certain number, it is recommended to specify 4-10. + - Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available. + - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using Adafactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required. + - Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side. + +- LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO! + - LoRA+ is a method to improve training speed by increasing the learning rate of the UP side (LoRA-B) of LoRA. Specify the multiple. The original paper recommends 16, but adjust as needed. Please see the PR for details. + - Specify `loraplus_lr_ratio` with `--network_args`. Example: `--network_args "loraplus_lr_ratio=16"` + - `loraplus_unet_lr_ratio` and `loraplus_lr_ratio` can be specified separately for U-Net and Text Encoder. + - Example: `--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` or `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` etc. + - `network_module` `networks.lora` and `networks.dylora` are available. + +- The feature to use the transparency (alpha channel) of the image as a mask in the loss calculation has been added. PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) Thanks to u-haru! + - The transparent part is ignored during training. Specify the `--alpha_mask` option in the training script or specify `alpha_mask = true` in the dataset configuration file. + - See [About masked loss](./docs/masked_loss_README.md) for details. + +- LoRA training in SDXL now supports block-wise learning rates and block-wise dim (rank). PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331) + - Specify the learning rate and dim (rank) for each block. + - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only). + +- Negative learning rates can now be specified during SDXL model training. PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Thanks to Cauldrath! + - The model is trained to move away from the training images, so the model is easily collapsed. Use with caution. A value close to 0 is recommended. + - When specifying from the command line, use `=` like `--learning_rate=-1e-7`. + +- Training scripts can now output training settings to wandb or Tensor Board logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa! + - Some settings, such as API keys and directory specifications, are not output due to security issues. + +- The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds! + +- `train_network.py` and `sdxl_train_network.py` now restore the order/position of data loading from DataSet when resuming training. PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) Thanks to KohakuBlueleaf! + - This resolves the issue where the order of data loading from DataSet changes when resuming training. + - Specify the `--skip_until_initial_step` option to skip data loading until the specified step. If not specified, data loading starts from the beginning of the DataSet (same as before). + - If `--resume` is specified, the step saved in the state is used. + - Specify the `--initial_step` or `--initial_epoch` option to skip data loading until the specified step or epoch. Use these options in conjunction with `--skip_until_initial_step`. These options can be used without `--resume` (use them when resuming training with `--network_weights`). + +- An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra! + - It seems that the model file loading is faster in the WSL environment etc. + - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`. + +- When there is an error in the cached latents file on disk, the file name is now displayed. PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Thanks to Cauldrath! + +- Fixed an error that occurs when specifying `--max_dataloader_n_workers` in `tag_images_by_wd14_tagger.py` when Onnx is not used. PR [#1291]( +https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290]( +https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! + +- Fixed a bug that `caption_separator` cannot be specified in the subset in the dataset settings .toml file. [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) and [#1313](https://github.com/kohya-ss/sd-scripts/pull/1312) Thanks to rockerBOO! + +- Fixed a potential bug in ControlNet-LLLite training. PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) Thanks to aria1th! + +- Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) + +- Added a prompt option `--f` to `gen_imgs.py` to specify the file name when saving. Also, Diffusers-based keys for LoRA weights are now supported. + +- SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 + - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。 + - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は Adafactor のみ対応しています。また gradient accumulation は使えません。 + - mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。 + - バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。 + - PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。 + - 仕組み:通常は backward -> step の順で行うためすべての勾配を一時的にメモリに保持する必要があります。「backward と step の統合」はパラメータごとに backward/step を行って、勾配をすぐ反映することでメモリ使用量を削減します。パラメータ数が多いほど効果が大きいため、SDXL の学習以外(LoRA 等)ではほぼ効果がなく(メモリ使用量のピークが他の場所にあるため)、それらの学習スクリプトへの実装予定もありません。 + +- SDXL の学習時に optimizer group 機能を追加しました。PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319) + - Fused optimizer と同様の原理でメモリ使用量を削減します。学習結果や速度についても同様です。 + - `sdxl_train.py` に `--fused_optimizer_groups 10` のようにグループ数を指定してください。グループ数を増やすとメモリ使用量が削減されますが、速度は遅くなります。ある程度の数までしか効果がないため、4~10 程度を指定すると良いでしょう。 + - 任意の optimizer が使えますが、学習率を自動計算する optimizer (D-Adaptation や Prodigy など)は使えません。gradient accumulation は使えません。 + - `--fused_optimizer_groups` は `--fused_backward_pass` と併用できません。AdaFactor 使用時は Fused optimizer よりも若干メモリ使用量は大きくなります。PyTorch 2.1 以降が必要です。 + - 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。やはり SDXL の学習でのみ効果があります。 + +- LoRA+ がサポートされました。PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) rockerBOO 氏に感謝します。 + - LoRA の UP 側(LoRA-B)の学習率を上げることで学習速度の向上を図る手法です。倍数で指定します。元の論文では 16 が推奨されていますが、データセット等にもよりますので、適宜調整してください。PR もあわせてご覧ください。 + - `--network_args` で `loraplus_lr_ratio` を指定します。例:`--network_args "loraplus_lr_ratio=16"` + - `loraplus_unet_lr_ratio` と `loraplus_lr_ratio` で、U-Net および Text Encoder に個別の値を指定することも可能です。 + - 例:`--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` または `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` など + - `network_module` の `networks.lora` および `networks.dylora` で使用可能です。 + +- 画像の透明度(アルファチャネル)をロス計算時のマスクとして使用する機能が追加されました。PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) u-haru 氏に感謝します。 + - 透明部分が学習時に無視されるようになります。学習スクリプトに `--alpha_mask` オプションを指定するか、データセット設定ファイルに `alpha_mask = true` を指定してください。 + - 詳細は [マスクロスについて](./docs/masked_loss_README-ja.md) をご覧ください。 + +- SDXL の LoRA で階層別学習率、階層別 dim (rank) をサポートしました。PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331) + - ブロックごとに学習率および dim (rank) を指定することができます。 + - 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。 + +- `sdxl_train.py` での SDXL モデル学習時に負の学習率が指定できるようになりました。PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Cauldrath 氏に感謝します。 + - 学習画像から離れるように学習するため、モデルは容易に崩壊します。注意して使用してください。0 に近い値を推奨します。 + - コマンドラインから指定する場合、`--learning_rate=-1e-7` のように`=` を使ってください。 + +- 各学習スクリプトで学習設定を wandb や Tensor Board などのログに出力できるようになりました。`--log_config` オプションを指定してください。PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) ccharest93 氏、plucked 氏、rockerBOO 氏および VelocityRa 氏に感謝します。 + - API キーや各種ディレクトリ指定など、一部の設定はセキュリティ上の問題があるため出力されません。 + +- SD1.5/2.x 用の ControlNet 学習スクリプト `train_controlnet.py` が動作しなくなっていたのが修正されました。PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) sdbds 氏に感謝します。 + +- `train_network.py` および `sdxl_train_network.py` で、学習再開時に DataSet の読み込み順についても復元できるようになりました。PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) KohakuBlueleaf 氏に感謝します。 + - これにより、学習再開時に DataSet の読み込み順が変わってしまう問題が解消されます。 + - `--skip_until_initial_step` オプションを指定すると、指定したステップまで DataSet 読み込みをスキップします。指定しない場合の動作は変わりません(DataSet の最初から読み込みます) + - `--resume` オプションを指定すると、state に保存されたステップ数が使用されます。 + - `--initial_step` または `--initial_epoch` オプションを指定すると、指定したステップまたはエポックまで DataSet 読み込みをスキップします。これらのオプションは `--skip_until_initial_step` と併用してください。またこれらのオプションは `--resume` と併用しなくても使えます(`--network_weights` を用いた学習再開時などにお使いください )。 + +- SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。 + - WSL 環境等でモデルファイルの読み込みが高速化されるようです。 + - `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。 + +- ディスクにキャッシュされた latents ファイルに何らかのエラーがあったとき、そのファイル名が表示されるようになりました。 PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Cauldrath 氏に感謝します。 + +- `tag_images_by_wd14_tagger.py` で Onnx 未使用時に `--max_dataloader_n_workers` を指定するとエラーになる不具合が修正されました。 PR [#1291]( +https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290]( +https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します。 + +- データセット設定の .toml ファイルで、`caption_separator` が subset に指定できない不具合が修正されました。 PR [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) および [#1313](https://github.com/kohya-ss/sd-scripts/pull/1313) rockerBOO 氏に感謝します。 + +- ControlNet-LLLite 学習時の潜在バグが修正されました。 PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) aria1th 氏に感謝します。 + +- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) + +- `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。また同スクリプトで Diffusers ベースのキーを持つ LoRA の重みに対応しました。 + + +### Oct 27, 2024 / 2024-10-27: + +- `svd_merge_lora.py` VRAM usage has been reduced. However, main memory usage will increase (32GB is sufficient). +- This will be included in the next release. +- `svd_merge_lora.py` のVRAM使用量を削減しました。ただし、メインメモリの使用量は増加します(32GBあれば十分です)。 +- これは次回リリースに含まれます。 + +### Oct 26, 2024 / 2024-10-26: + +- Fixed a bug in `svd_merge_lora.py`, `sdxl_merge_lora.py`, and `resize_lora.py` where the hash value of LoRA metadata was not correctly calculated when the `save_precision` was different from the `precision` used in the calculation. See issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) for details. Thanks to JujoHotaru for raising the issue. +- It will be included in the next release. + +- `svd_merge_lora.py`、`sdxl_merge_lora.py`、`resize_lora.py`で、保存時の精度が計算時の精度と異なる場合、LoRAメタデータのハッシュ値が正しく計算されない不具合を修正しました。詳細は issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) をご覧ください。問題提起していただいた JujoHotaru 氏に感謝します。 +- 以上は次回リリースに含まれます。 + +### Sep 13, 2024 / 2024-09-13: + +- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580). +- `svd_merge_lora.py` now supports LBW. Thanks to terracottahaniwa. See PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) for details. +- `sdxl_merge_lora.py` also supports LBW. +- See [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) by hako-mikan for details on LBW. +- These will be included in the next release. + +- `sdxl_merge_lora.py` が OFT をサポートされました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。 +- `svd_merge_lora.py` で LBW がサポートされました。PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) terracottahaniwa 氏に感謝します。 +- `sdxl_merge_lora.py` でも LBW がサポートされました。 +- LBW の詳細は hako-mikan 氏の [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) をご覧ください。 +- 以上は次回リリースに含まれます。 + +### Jun 23, 2024 / 2024-06-23: + +- Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.) + +- `cache_latents.py` および `cache_text_encoder_outputs.py` が動作しなくなっていたのを修正しました。(次回リリースに含まれます。) + ### Apr 7, 2024 / 2024-04-07: v0.8.7 - The default value of `huber_schedule` in Scheduled Huber Loss is changed from `exponential` to `snr`, which is expected to give better results. diff --git a/_typos.toml b/_typos.toml index ae9e06b1..bbf7728f 100644 --- a/_typos.toml +++ b/_typos.toml @@ -2,6 +2,7 @@ # Instruction: https://github.com/marketplace/actions/typos-action#getting-started [default.extend-identifiers] +ddPn08="ddPn08" [default.extend-words] NIN="NIN" @@ -27,6 +28,7 @@ rik="rik" koo="koo" yos="yos" wn="wn" +hime="hime" [files] diff --git a/docs/config_README-en.md b/docs/config_README-en.md index 83bea329..66a50dc0 100644 --- a/docs/config_README-en.md +++ b/docs/config_README-en.md @@ -128,6 +128,8 @@ These are options related to the configuration of the data set. They cannot be d * `batch_size` * This corresponds to the command-line argument `--train_batch_size`. +* `max_bucket_reso`, `min_bucket_reso` + * Specify the maximum and minimum resolutions of the bucket. It must be divisible by `bucket_reso_steps`. These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each. diff --git a/docs/config_README-ja.md b/docs/config_README-ja.md index cc74c341..0ed95e0e 100644 --- a/docs/config_README-ja.md +++ b/docs/config_README-ja.md @@ -118,6 +118,8 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 * `batch_size` * コマンドライン引数の `--train_batch_size` と同等です。 +* `max_bucket_reso`, `min_bucket_reso` + * bucketの最大、最小解像度を指定します。`bucket_reso_steps` で割り切れる必要があります。 これらの設定はデータセットごとに固定です。 つまり、データセットに所属するサブセットはこれらの設定を共有することになります。 diff --git a/docs/masked_loss_README-ja.md b/docs/masked_loss_README-ja.md new file mode 100644 index 00000000..58f042c3 --- /dev/null +++ b/docs/masked_loss_README-ja.md @@ -0,0 +1,57 @@ +## マスクロスについて + +マスクロスは、入力画像のマスクで指定された部分だけ損失計算することで、画像の一部分だけを学習することができる機能です。 +たとえばキャラクタを学習したい場合、キャラクタ部分だけをマスクして学習することで、背景を無視して学習することができます。 + +マスクロスのマスクには、二種類の指定方法があります。 + +- マスク画像を用いる方法 +- 透明度(アルファチャネル)を使用する方法 + +なお、サンプルは [ずんずんPJイラスト/3Dデータ](https://zunko.jp/con_illust.html) の「AI画像モデル用学習データ」を使用しています。 + +### マスク画像を用いる方法 + +学習画像それぞれに対応するマスク画像を用意する方法です。学習画像と同じファイル名のマスク画像を用意し、それを学習画像と別のディレクトリに保存します。 + +- 学習画像 + ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/607c5116-5f62-47de-8b66-9c4a597f0441) +- マスク画像 + ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/53e9b0f8-a4bf-49ed-882d-4026f84e8450) + +```.toml +[[datasets.subsets]] +image_dir = "/path/to/a_zundamon" +caption_extension = ".txt" +conditioning_data_dir = "/path/to/a_zundamon_mask" +num_repeats = 8 +``` + +マスク画像は、学習画像と同じサイズで、学習する部分を白、無視する部分を黒で描画します。グレースケールにも対応しています(127 ならロス重みが 0.5 になります)。なお、正確にはマスク画像の R チャネルが用いられます。 + +DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにマスク画像を保存してください。ControlNet のデータセットと同じですので、詳細は [ControlNet-LLLite](train_lllite_README-ja.md#データセットの準備) を参照してください。 + +### 透明度(アルファチャネル)を使用する方法 + +学習画像の透明度(アルファチャネル)がマスクとして使用されます。透明度が 0 の部分は無視され、255 の部分は学習されます。半透明の場合は、その透明度に応じてロス重みが変化します(127 ならおおむね 0.5)。 + +![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/0baa129b-446a-4aac-b98c-7208efb0e75e) + +※それぞれの画像は透過PNG + +学習時のスクリプトのオプションに `--alpha_mask` を指定するか、dataset の設定ファイルの subset で、`alpha_mask` を指定してください。たとえば、以下のようになります。 + +```toml +[[datasets.subsets]] +image_dir = "/path/to/image/dir" +caption_extension = ".txt" +num_repeats = 8 +alpha_mask = true +``` + +## 学習時の注意事項 + +- 現時点では DreamBooth 方式の dataset のみ対応しています。 +- マスクは latents のサイズ、つまり 1/8 に縮小されてから適用されます。そのため、細かい部分(たとえばアホ毛やイヤリングなど)はうまく学習できない可能性があります。マスクをわずかに拡張するなどの工夫が必要かもしれません。 +- マスクロスを用いる場合、学習対象外の部分をキャプションに含める必要はないかもしれません。(要検証) +- `alpha_mask` の場合、マスクの有無を切り替えると latents キャッシュが自動的に再生成されます。 diff --git a/docs/masked_loss_README.md b/docs/masked_loss_README.md new file mode 100644 index 00000000..3ac5ad21 --- /dev/null +++ b/docs/masked_loss_README.md @@ -0,0 +1,56 @@ +## Masked Loss + +Masked loss is a feature that allows you to train only part of an image by calculating the loss only for the part specified by the mask of the input image. For example, if you want to train a character, you can train only the character part by masking it, ignoring the background. + +There are two ways to specify the mask for masked loss. + +- Using a mask image +- Using transparency (alpha channel) of the image + +The sample uses the "AI image model training data" from [ZunZunPJ Illustration/3D Data](https://zunko.jp/con_illust.html). + +### Using a mask image + +This is a method of preparing a mask image corresponding to each training image. Prepare a mask image with the same file name as the training image and save it in a different directory from the training image. + +- Training image + ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/607c5116-5f62-47de-8b66-9c4a597f0441) +- Mask image + ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/53e9b0f8-a4bf-49ed-882d-4026f84e8450) + +```.toml +[[datasets.subsets]] +image_dir = "/path/to/a_zundamon" +caption_extension = ".txt" +conditioning_data_dir = "/path/to/a_zundamon_mask" +num_repeats = 8 +``` + +The mask image is the same size as the training image, with the part to be trained drawn in white and the part to be ignored in black. It also supports grayscale (127 gives a loss weight of 0.5). The R channel of the mask image is used currently. + +Use the dataset in the DreamBooth method, and save the mask image in the directory specified by `conditioning_data_dir`. It is the same as the ControlNet dataset, so please refer to [ControlNet-LLLite](train_lllite_README.md#Preparing-the-dataset) for details. + +### Using transparency (alpha channel) of the image + +The transparency (alpha channel) of the training image is used as a mask. The part with transparency 0 is ignored, the part with transparency 255 is trained. For semi-transparent parts, the loss weight changes according to the transparency (127 gives a weight of about 0.5). + +![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/0baa129b-446a-4aac-b98c-7208efb0e75e) + +※Each image is a transparent PNG + +Specify `--alpha_mask` in the training script options or specify `alpha_mask` in the subset of the dataset configuration file. For example, it will look like this. + +```toml +[[datasets.subsets]] +image_dir = "/path/to/image/dir" +caption_extension = ".txt" +num_repeats = 8 +alpha_mask = true +``` + +## Notes on training + +- At the moment, only the dataset in the DreamBooth method is supported. +- The mask is applied after the size is reduced to 1/8, which is the size of the latents. Therefore, fine details (such as ahoge or earrings) may not be learned well. Some dilations of the mask may be necessary. +- If using masked loss, it may not be necessary to include parts that are not to be trained in the caption. (To be verified) +- In the case of `alpha_mask`, the latents cache is automatically regenerated when the enable/disable state of the mask is switched. diff --git a/docs/train_README-ja.md b/docs/train_README-ja.md index d186bf24..cfa5a7d1 100644 --- a/docs/train_README-ja.md +++ b/docs/train_README-ja.md @@ -648,7 +648,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b 詳細については各自お調べください。 - 任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--scheduler_args`でオプション引数を指定してください。 + 任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--lr_scheduler_args`でオプション引数を指定してください。 ### オプティマイザの指定について diff --git a/docs/train_README-zh.md b/docs/train_README-zh.md index 7e00278c..1bc47e0f 100644 --- a/docs/train_README-zh.md +++ b/docs/train_README-zh.md @@ -582,7 +582,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b 有关详细信息,请自行研究。 - 要使用任何调度程序,请像使用任何优化器一样使用“--scheduler_args”指定可选参数。 + 要使用任何调度程序,请像使用任何优化器一样使用“--lr_scheduler_args”指定可选参数。 ### 关于指定优化器 使用 --optimizer_args 选项指定优化器选项参数。可以以key=value的格式指定多个值。此外,您可以指定多个值,以逗号分隔。例如,要指定 AdamW 优化器的参数,``--optimizer_args weight_decay=0.01 betas=.9,.999``。 diff --git a/docs/train_lllite_README.md b/docs/train_lllite_README.md index a05f87f5..1bd8e4ae 100644 --- a/docs/train_lllite_README.md +++ b/docs/train_lllite_README.md @@ -185,7 +185,7 @@ for img_file in img_files: ### Creating a dataset configuration file -You can use the command line arguments of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`. +You can use the command line argument `--conditioning_data_dir` of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`. ```toml [general] diff --git a/docs/train_network_README-ja.md b/docs/train_network_README-ja.md index 2205a773..55c80c4b 100644 --- a/docs/train_network_README-ja.md +++ b/docs/train_network_README-ja.md @@ -102,6 +102,8 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py * Text Encoderに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。Text Encoderのほうを若干低めの学習率(5e-5など)にしたほうが良い、という話もあるようです。 * `--network_args` * 複数の引数を指定できます。後述します。 +* `--alpha_mask` + * 画像のアルファ値をマスクとして使用します。透過画像を学習する際に使用します。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223) `--network_train_unet_only` と `--network_train_text_encoder_only` の両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。 @@ -181,16 +183,16 @@ python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.saf 詳細は[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) をご覧ください。 -SDXLは現在サポートしていません。 - フルモデルの25個のブロックの重みを指定できます。最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。 +SDXL では down/up 9 個、middle 3 個の値を指定してください。 + `--network_args` で以下の引数を指定してください。 - `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。 - - ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。 + - ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個(SDXL では 9 個)の数値を指定します。 - プリセットからの指定 : `"down_lr_weight=sine"` のように指定します(サインカーブで重みを指定します)。sine, cosine, linear, reverse_linear, zeros が指定可能です。また `"down_lr_weight=cosine+.25"` のように `+数値` を追加すると、指定した数値を加算します(0.25~1.25になります)。 -- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します。 +- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します(SDXL の場合は 3 個)。 - `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。 - 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。 - `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。 @@ -215,6 +217,9 @@ network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_l フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。 +SDXL では 23 個の値を指定してください。一部のブロックにはLoRA が存在しませんが、`sdxl_train.py` の[階層別学習率](./train_SDXL-en.md) との互換性のためです。 +対応は、`0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out` です。 + `--network_args` で以下の引数を指定してください。 - `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。 diff --git a/docs/train_network_README-zh.md b/docs/train_network_README-zh.md index ed7a0c4e..830014f7 100644 --- a/docs/train_network_README-zh.md +++ b/docs/train_network_README-zh.md @@ -101,6 +101,8 @@ LoRA的模型将会被保存在通过`--output_dir`选项指定的文件夹中 * 当在Text Encoder相关的LoRA模块中使用与常规学习率(由`--learning_rate`选项指定)不同的学习率时,应指定此选项。可能最好将Text Encoder的学习率稍微降低(例如5e-5)。 * `--network_args` * 可以指定多个参数。将在下面详细说明。 +* `--alpha_mask` + * 使用图像的 Alpha 值作为遮罩。这在学习透明图像时使用。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223) 当未指定`--network_train_unet_only`和`--network_train_text_encoder_only`时(默认情况),将启用Text Encoder和U-Net的两个LoRA模块。 diff --git a/fine_tune.py b/fine_tune.py index c7e6bbd2..17608706 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -10,7 +10,7 @@ import toml from tqdm import tqdm import torch -from library import deepspeed_utils +from library import deepspeed_utils, strategy_base from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -39,6 +39,7 @@ from library.custom_train_functions import ( scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, ) +import library.strategy_sd as strategy_sd def train(args): @@ -52,7 +53,15 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - tokenizer = train_util.load_tokenizer(args) + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if cache_latents: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -81,16 +90,18 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + train_dataset_group.verify_bucket_reso_steps(64) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return @@ -165,8 +176,9 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator) + vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -192,6 +204,9 @@ def train(args): else: text_encoder.eval() + text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + if not cache_latents: vae.requires_grad_(False) vae.eval() @@ -214,7 +229,11 @@ def train(args): accelerator.print("prepare optimizer, data loader etc.") _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -310,10 +329,19 @@ def train(args): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) # For --sample_at_first - train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -338,23 +366,21 @@ def train(args): with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, [text_encoder], input_ids_list, weights_list + )[0] else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) + input_ids = batch["input_ids_list"][0].to(accelerator.device) + encoder_hidden_states = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder], [input_ids] + )[0] + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -366,9 +392,10 @@ def train(args): else: target = noise + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: @@ -376,11 +403,11 @@ def train(args): if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: @@ -399,7 +426,7 @@ def train(args): global_step += 1 train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet ) # 指定ステップごとにモデルを保存 @@ -424,7 +451,7 @@ def train(args): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) accelerator.log(logs, step=global_step) @@ -437,7 +464,7 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) @@ -462,7 +489,9 @@ def train(args): vae, ) - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) is_main_process = accelerator.is_main_process if is_main_process: @@ -471,7 +500,7 @@ def train(args): accelerator.end_training() - if is_main_process and (args.save_state or args.save_state_on_train_end): + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 0389da38..019c737a 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -11,6 +11,7 @@ import cv2 import torch from library.device_utils import init_ipex, get_preferred_device + init_ipex() from torchvision import transforms @@ -18,8 +19,10 @@ from torchvision import transforms import library.model_util as model_util import library.train_util as train_util from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) DEVICE = get_preferred_device() @@ -89,7 +92,9 @@ def main(args): # bucketのサイズを計算する max_reso = tuple([int(t) for t in args.max_resolution.split(",")]) - assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" + assert ( + len(max_reso) == 2 + ), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" bucket_manager = train_util.BucketManager( args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps @@ -107,7 +112,7 @@ def main(args): def process_batch(is_last): for bucket in bucket_manager.buckets: if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: - train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False) + train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, args.alpha_mask, False) bucket.clear() # 読み込みの高速化のためにDataLoaderを使うオプション @@ -208,7 +213,9 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル") - parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)") + parser.add_argument( + "--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)" + ) parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument( "--max_data_loader_n_workers", @@ -231,10 +238,16 @@ def setup_parser() -> argparse.ArgumentParser: help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します", ) parser.add_argument( - "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" + "--bucket_no_upscale", + action="store_true", + help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します", ) parser.add_argument( - "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度" + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help="use mixed precision / 混合精度を使う場合、その精度", ) parser.add_argument( "--full_path", @@ -242,7 +255,15 @@ def setup_parser() -> argparse.ArgumentParser: help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", ) parser.add_argument( - "--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する" + "--flip_aug", + action="store_true", + help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する", + ) + parser.add_argument( + "--alpha_mask", + type=str, + default="", + help="save alpha mask for images for loss calculation / 損失計算用に画像のアルファマスクを保存する", ) parser.add_argument( "--skip_existing", diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index a327bbd6..cbc3d2d6 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -11,7 +11,7 @@ from PIL import Image from tqdm import tqdm import library.train_util as train_util -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging @@ -42,8 +42,10 @@ def preprocess_image(image): pad_t = pad_y // 2 image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255) - interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 - image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) + if size > IMAGE_SIZE: + image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA) + else: + image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE)) image = image.astype(np.float32) return image @@ -112,7 +114,6 @@ def main(args): # モデルを読み込む if args.onnx: - import torch import onnx import onnxruntime as ort diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py new file mode 100644 index 00000000..7ab224f1 --- /dev/null +++ b/flux_minimal_inference.py @@ -0,0 +1,576 @@ +# Minimum Inference Code for FLUX + +import argparse +import datetime +import math +import os +import random +from typing import Callable, List, Optional +import einops +import numpy as np + +import torch +from tqdm import tqdm +from PIL import Image +import accelerate +from transformers import CLIPTextModel +from safetensors.torch import load_file + +from library import device_utils +from library.device_utils import init_ipex, get_preferred_device +from networks import oft_flux + +init_ipex() + + +from library.utils import setup_logging, str_to_dtype + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import networks.lora_flux as lora_flux +from library import flux_models, flux_utils, sd3_utils, strategy_flux + + +def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: flux_models.Flux, + img: torch.Tensor, + img_ids: torch.Tensor, + txt: torch.Tensor, + txt_ids: torch.Tensor, + vec: torch.Tensor, + timesteps: list[float], + guidance: float = 4.0, + t5_attn_mask: Optional[torch.Tensor] = None, + neg_txt: Optional[torch.Tensor] = None, + neg_vec: Optional[torch.Tensor] = None, + neg_t5_attn_mask: Optional[torch.Tensor] = None, + cfg_scale: Optional[float] = None, +): + # this is ignored for schnell + logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}") + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + + # prepare classifier free guidance + if neg_txt is not None and neg_vec is not None: + b_img_ids = torch.cat([img_ids, img_ids], dim=0) + b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0) + b_txt = torch.cat([neg_txt, txt], dim=0) + b_vec = torch.cat([neg_vec, vec], dim=0) + if t5_attn_mask is not None and neg_t5_attn_mask is not None: + b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0) + else: + b_t5_attn_mask = None + else: + b_img_ids = img_ids + b_txt_ids = txt_ids + b_txt = txt + b_vec = vec + b_t5_attn_mask = t5_attn_mask + + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): + t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device) + + # classifier free guidance + if neg_txt is not None and neg_vec is not None: + b_img = torch.cat([img, img], dim=0) + else: + b_img = img + + pred = model( + img=b_img, + img_ids=b_img_ids, + txt=b_txt, + txt_ids=b_txt_ids, + y=b_vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=b_t5_attn_mask, + ) + + # classifier free guidance + if neg_txt is not None and neg_vec is not None: + pred_uncond, pred = torch.chunk(pred, 2, dim=0) + pred = pred_uncond + cfg_scale * (pred - pred_uncond) + + img = img + (t_prev - t_curr) * pred + + return img + + +def do_sample( + accelerator: Optional[accelerate.Accelerator], + model: flux_models.Flux, + img: torch.Tensor, + img_ids: torch.Tensor, + l_pooled: torch.Tensor, + t5_out: torch.Tensor, + txt_ids: torch.Tensor, + num_steps: int, + guidance: float, + t5_attn_mask: Optional[torch.Tensor], + is_schnell: bool, + device: torch.device, + flux_dtype: torch.dtype, + neg_l_pooled: Optional[torch.Tensor] = None, + neg_t5_out: Optional[torch.Tensor] = None, + neg_t5_attn_mask: Optional[torch.Tensor] = None, + cfg_scale: Optional[float] = None, +): + logger.info(f"num_steps: {num_steps}") + timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell) + + # denoise initial noise + if accelerator: + with accelerator.autocast(), torch.no_grad(): + x = denoise( + model, + img, + img_ids, + t5_out, + txt_ids, + l_pooled, + timesteps, + guidance, + t5_attn_mask, + neg_t5_out, + neg_l_pooled, + neg_t5_attn_mask, + cfg_scale, + ) + else: + with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad(): + x = denoise( + model, + img, + img_ids, + t5_out, + txt_ids, + l_pooled, + timesteps, + guidance, + t5_attn_mask, + neg_t5_out, + neg_l_pooled, + neg_t5_attn_mask, + cfg_scale, + ) + + return x + + +def generate_image( + model, + clip_l: CLIPTextModel, + t5xxl, + ae, + prompt: str, + seed: Optional[int], + image_width: int, + image_height: int, + steps: Optional[int], + guidance: float, + negative_prompt: Optional[str], + cfg_scale: float, +): + seed = seed if seed is not None else random.randint(0, 2**32 - 1) + logger.info(f"Seed: {seed}") + + # make first noise with packed shape + # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2 + packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16) + noise_dtype = torch.float32 if is_fp8(dtype) else dtype + noise = torch.randn( + 1, + packed_latent_height * packed_latent_width, + 16 * 2 * 2, + device=device, + dtype=noise_dtype, + generator=torch.Generator(device=device).manual_seed(seed), + ) + + # prepare img and img ids + + # this is needed only for img2img + # img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + # if img.shape[0] == 1 and bs > 1: + # img = repeat(img, "1 ... -> bs ...", bs=bs) + + # txt2img only needs img_ids + img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width) + + # prepare fp8 models + if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared): + logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}") + clip_l.to(clip_l_dtype) # fp8 + clip_l.text_model.embeddings.to(dtype=torch.bfloat16) + clip_l.fp8_prepared = True + + if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared): + logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}") + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + t5xxl.to(t5xxl_dtype) + prepare_fp8(t5xxl.encoder, torch.bfloat16) + t5xxl.fp8_prepared = True + + # prepare embeddings + logger.info("Encoding prompts...") + clip_l = clip_l.to(device) + t5xxl = t5xxl.to(device) + + def encode(prpt: str): + tokens_and_masks = tokenize_strategy.tokenize(prpt) + with torch.no_grad(): + if is_fp8(clip_l_dtype): + with accelerator.autocast(): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + else: + with torch.autocast(device_type=device.type, dtype=clip_l_dtype): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + + if is_fp8(t5xxl_dtype): + with accelerator.autocast(): + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + else: + with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + return l_pooled, t5_out, txt_ids, t5_attn_mask + + l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt) + if negative_prompt: + neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt) + else: + neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None + + # NaN check + if torch.isnan(l_pooled).any(): + raise ValueError("NaN in l_pooled") + if torch.isnan(t5_out).any(): + raise ValueError("NaN in t5_out") + + if args.offload: + clip_l = clip_l.cpu() + t5xxl = t5xxl.cpu() + # del clip_l, t5xxl + device_utils.clean_memory() + + # generate image + logger.info("Generating image...") + model = model.to(device) + if steps is None: + steps = 4 if is_schnell else 50 + + img_ids = img_ids.to(device) + t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None + + x = do_sample( + accelerator, + model, + noise, + img_ids, + l_pooled, + t5_out, + txt_ids, + steps, + guidance, + t5_attn_mask, + is_schnell, + device, + flux_dtype, + neg_l_pooled, + neg_t5_out, + neg_t5_attn_mask, + cfg_scale, + ) + if args.offload: + model = model.cpu() + # del model + device_utils.clean_memory() + + # unpack + x = x.float() + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) + + # decode + logger.info("Decoding image...") + ae = ae.to(device) + with torch.no_grad(): + if is_fp8(ae_dtype): + with accelerator.autocast(): + x = ae.decode(x) + else: + with torch.autocast(device_type=device.type, dtype=ae_dtype): + x = ae.decode(x) + if args.offload: + ae = ae.cpu() + + x = x.clamp(-1, 1) + x = x.permute(0, 2, 3, 1) + img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + + # save image + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") + img.save(output_path) + + logger.info(f"Saved image to {output_path}") + + +if __name__ == "__main__": + target_height = 768 # 1024 + target_width = 1360 # 1024 + + # steps = 50 # 28 # 50 + # guidance_scale = 5 + # seed = 1 # None # 1 + + device = get_preferred_device() + + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--clip_l", type=str, required=False) + parser.add_argument("--t5xxl", type=str, required=False) + parser.add_argument("--ae", type=str, required=False) + parser.add_argument("--apply_t5_attn_mask", action="store_true") + parser.add_argument("--prompt", type=str, default="A photo of a cat") + parser.add_argument("--output_dir", type=str, default=".") + parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype") + parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l") + parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae") + parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl") + parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux") + parser.add_argument("--seed", type=int, default=None) + parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev") + parser.add_argument("--guidance", type=float, default=3.5) + parser.add_argument("--negative_prompt", type=str, default=None) + parser.add_argument("--cfg_scale", type=float, default=1.0) + parser.add_argument("--offload", action="store_true", help="Offload to CPU") + parser.add_argument( + "--lora_weights", + type=str, + nargs="*", + default=[], + help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)", + ) + parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") + parser.add_argument("--width", type=int, default=target_width) + parser.add_argument("--height", type=int, default=target_height) + parser.add_argument("--interactive", action="store_true") + args = parser.parse_args() + + seed = args.seed + steps = args.steps + guidance_scale = args.guidance + + def is_fp8(dt): + return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz] + + dtype = str_to_dtype(args.dtype) + clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype) + t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype) + ae_dtype = str_to_dtype(args.ae_dtype, dtype) + flux_dtype = str_to_dtype(args.flux_dtype, dtype) + + logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}") + + loading_device = "cpu" if args.offload else device + + use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]] + if any(use_fp8): + accelerator = accelerate.Accelerator(mixed_precision="bf16") + else: + accelerator = None + + # load clip_l + logger.info(f"Loading clip_l from {args.clip_l}...") + clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device) + clip_l.eval() + + logger.info(f"Loading t5xxl from {args.t5xxl}...") + t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device) + t5xxl.eval() + + # if is_fp8(clip_l_dtype): + # clip_l = accelerator.prepare(clip_l) + # if is_fp8(t5xxl_dtype): + # t5xxl = accelerator.prepare(t5xxl) + + # DiT + is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device) + model.eval() + logger.info(f"Casting model to {flux_dtype}") + model.to(flux_dtype) # make sure model is dtype + # if is_fp8(flux_dtype): + # model = accelerator.prepare(model) + # if args.offload: + # model = model.to("cpu") + + t5xxl_max_length = 256 if is_schnell else 512 + tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length) + encoding_strategy = strategy_flux.FluxTextEncodingStrategy() + + # AE + ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device) + ae.eval() + # if is_fp8(ae_dtype): + # ae = accelerator.prepare(ae) + + # LoRA + lora_models: List[lora_flux.LoRANetwork] = [] + for weights_file in args.lora_weights: + if ";" in weights_file: + weights_file, multiplier = weights_file.split(";") + multiplier = float(multiplier) + else: + multiplier = 1.0 + + weights_sd = load_file(weights_file) + is_lora = is_oft = False + for key in weights_sd.keys(): + if key.startswith("lora"): + is_lora = True + if key.startswith("oft"): + is_oft = True + if is_lora or is_oft: + break + + module = lora_flux if is_lora else oft_flux + lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True) + + if args.merge_lora_weights: + lora_model.merge_to([clip_l, t5xxl], model, weights_sd) + else: + lora_model.apply_to([clip_l, t5xxl], model) + info = lora_model.load_state_dict(weights_sd, strict=True) + logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + lora_model.eval() + lora_model.to(device) + + lora_models.append(lora_model) + + if not args.interactive: + generate_image( + model, + clip_l, + t5xxl, + ae, + args.prompt, + args.seed, + args.width, + args.height, + args.steps, + args.guidance, + args.negative_prompt, + args.cfg_scale, + ) + else: + # loop for interactive + width = target_width + height = target_height + steps = None + guidance = args.guidance + cfg_scale = args.cfg_scale + + while True: + print( + "Enter prompt (empty to exit). Options: --w --h --s --d --g --m " + " --n , `-` for empty negative prompt --c " + ) + prompt = input() + if prompt == "": + break + + # parse options + options = prompt.split("--") + prompt = options[0].strip() + seed = None + negative_prompt = None + for opt in options[1:]: + try: + opt = opt.strip() + if opt.startswith("w"): + width = int(opt[1:].strip()) + elif opt.startswith("h"): + height = int(opt[1:].strip()) + elif opt.startswith("s"): + steps = int(opt[1:].strip()) + elif opt.startswith("d"): + seed = int(opt[1:].strip()) + elif opt.startswith("g"): + guidance = float(opt[1:].strip()) + elif opt.startswith("m"): + mutipliers = opt[1:].strip().split(",") + if len(mutipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(mutipliers[i])) + elif opt.startswith("n"): + negative_prompt = opt[1:].strip() + if negative_prompt == "-": + negative_prompt = "" + elif opt.startswith("c"): + cfg_scale = float(opt[1:].strip()) + except ValueError as e: + logger.error(f"Invalid option: {opt}, {e}") + + generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale) + + logger.info("Done!") diff --git a/flux_train.py b/flux_train.py new file mode 100644 index 00000000..fced3bef --- /dev/null +++ b/flux_train.py @@ -0,0 +1,849 @@ +# training with captions + +# Swap blocks between CPU and GPU: +# This implementation is inspired by and based on the work of 2kpr. +# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading. +# The original idea has been adapted and extended to fit the current project's needs. + +# Key features: +# - CPU offloading during forward and backward passes +# - Use of fused optimizer and grad_hook for efficient gradient processing +# - Per-block fused optimizer instances + +import argparse +from concurrent.futures import ThreadPoolExecutor +import copy +import math +import os +from multiprocessing import Value +import time +from typing import List, Optional, Tuple, Union +import toml + +from tqdm import tqdm + +import torch +import torch.nn as nn +from library import utils +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cpu_offload_checkpointing and not args.gradient_checkpointing: + logger.warning( + "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります" + ) + args.gradient_checkpointing = True + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, ( + "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False + ) + ) + t5xxl_max_token_length = ( + args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512) + ) + strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) + + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + + # load VAE for caching latents + ae = None + if cache_latents: + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae.to(accelerator.device, dtype=weight_dtype) + ae.requires_grad_(False) + ae.eval() + + train_dataset_group.new_cache_latents(ae, accelerator) + + ae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # prepare tokenize strategy + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy) + + # load clip_l, t5xxl for caching text encoder outputs + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + clip_l.eval() + t5xxl.eval() + clip_l.requires_grad_(False) + t5xxl.requires_grad_(False) + + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad here + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = flux_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + clip_l = None + t5xxl = None + clean_memory_on_device(accelerator.device) + + # load FLUX + _, flux = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + ) + + if args.gradient_checkpointing: + flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) + + flux.requires_grad_(True) + + # block swap + + # backward compatibility + if args.blocks_to_swap is None: + blocks_to_swap = args.double_blocks_to_swap or 0 + if args.single_blocks_to_swap is not None: + blocks_to_swap += args.single_blocks_to_swap // 2 + if blocks_to_swap > 0: + logger.warning( + "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead." + " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。" + ) + logger.info( + f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}." + ) + args.blocks_to_swap = blocks_to_swap + del blocks_to_swap + + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + flux.enable_block_swap(args.blocks_to_swap, accelerator.device) + + if not cache_latents: + # load VAE here if not cached + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu") + ae.requires_grad_(False) + ae.eval() + ae.to(accelerator.device, dtype=weight_dtype) + + training_models = [] + params_to_optimize = [] + training_models.append(flux) + name_and_params = list(flux.named_parameters()) + # single param group for now + params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate}) + param_names = [[n for n, _ in name_and_params]] + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.blockwise_fused_optimizers: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. + # This balances memory usage and management complexity. + + # split params into groups. currently different learning rates are not supported + grouped_params = [] + param_group = {} + for group in params_to_optimize: + named_parameters = list(flux.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # double, single or other + if np[0].startswith("double_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "double" + elif np[0].startswith("single_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "single" + else: + block_index = -1 + + param_group_key = (block_type, block_index) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.blockwise_fused_optimizers: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + flux.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) # TODO check works with fp16 or not + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + flux.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) + + # if we don't cache text encoder outputs, move them to device + if not args.cache_text_encoder_outputs: + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=flux) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # accelerator does some magic + # if we doesn't swap blocks, we can move the model to device + flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks]) + if is_swapping_blocks: + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): + if parameter.requires_grad: + + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook + + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) + + elif args.blockwise_fused_optimizers: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if is_swapping_blocks: + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() + + # For --sample_at_first + optimizer_eval_fn() + flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + optimizer_train_fn() + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) + + loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.blockwise_fused_optimizers: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + + # call model + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = flux( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + optimizer_eval_fn() + flux_train_utils.sample_images( + accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(flux), + ) + optimizer_train_fn() + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if len(accelerator.trackers) > 0: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + optimizer_eval_fn() + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(flux), + ) + + flux_train_utils.sample_images( + accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + ) + optimizer_train_fn() + + is_main_process = accelerator.is_main_process + # if is_main_process: + flux = accelerator.unwrap_model(flux) + + accelerator.end_training() + optimizer_eval_fn() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) # TODO split this + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) # TODO remove this from here + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + + parser.add_argument( + "--mem_eff_save", + action="store_true", + help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う", + ) + + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます", + ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", + ) + parser.add_argument( + "--double_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) + parser.add_argument( + "--single_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/flux_train_control_net.py b/flux_train_control_net.py new file mode 100644 index 00000000..9d36a41d --- /dev/null +++ b/flux_train_control_net.py @@ -0,0 +1,877 @@ +# training with captions + +# Swap blocks between CPU and GPU: +# This implementation is inspired by and based on the work of 2kpr. +# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading. +# The original idea has been adapted and extended to fit the current project's needs. + +# Key features: +# - CPU offloading during forward and backward passes +# - Use of fused optimizer and grad_hook for efficient gradient processing +# - Per-block fused optimizer instances + +import argparse +import copy +import math +import os +import time +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Value +from typing import List, Optional, Tuple, Union + +import toml +import torch +import torch.nn as nn +from tqdm import tqdm + +from library import utils +from library.device_utils import clean_memory_on_device, init_ipex + +init_ipex() + +from accelerate.utils import set_seed + +import library.train_util as train_util +from library import ( + deepspeed_utils, + flux_train_utils, + flux_utils, + strategy_base, + strategy_flux, +) +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler +from library.utils import add_logging_arguments, setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + BlueprintGenerator, + ConfigSanitizer, +) +from library.custom_train_functions import add_custom_train_arguments, apply_masked_loss + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cpu_offload_checkpointing and not args.gradient_checkpointing: + logger.warning( + "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります" + ) + args.gradient_checkpointing = True + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, ( + "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) + + cache_latents = args.cache_latents + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "conditioning_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, args.conditioning_data_dir, args.caption_extension + ) + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False + ) + ) + t5xxl_max_token_length = ( + args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512) + ) + strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) + + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + + # load VAE for caching latents + ae = None + if cache_latents: + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae.to(accelerator.device, dtype=weight_dtype) + ae.requires_grad_(False) + ae.eval() + + train_dataset_group.new_cache_latents(ae, accelerator) + + ae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # prepare tokenize strategy + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy) + + # load clip_l, t5xxl for caching text encoder outputs + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + clip_l.eval() + t5xxl.eval() + clip_l.requires_grad_(False) + t5xxl.requires_grad_(False) + + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad here + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = flux_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + clip_l = None + t5xxl = None + clean_memory_on_device(accelerator.device) + + # load FLUX + is_schnell, flux = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + ) + flux.requires_grad_(False) + + # load controlnet + controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype + controlnet = flux_utils.load_controlnet( + args.controlnet_model_name_or_path, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors + ) + controlnet.train() + + if args.gradient_checkpointing: + if not args.deepspeed: + flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) + controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) + + # block swap + + # backward compatibility + if args.blocks_to_swap is None: + blocks_to_swap = args.double_blocks_to_swap or 0 + if args.single_blocks_to_swap is not None: + blocks_to_swap += args.single_blocks_to_swap // 2 + if blocks_to_swap > 0: + logger.warning( + "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead." + " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。" + ) + logger.info( + f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}." + ) + args.blocks_to_swap = blocks_to_swap + del blocks_to_swap + + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + flux.enable_block_swap(args.blocks_to_swap, accelerator.device) + flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + # ControlNet only has two blocks, so we can keep it on GPU + # controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device) + else: + flux.to(accelerator.device) + + if not cache_latents: + # load VAE here if not cached + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu") + ae.requires_grad_(False) + ae.eval() + ae.to(accelerator.device, dtype=weight_dtype) + + training_models = [] + params_to_optimize = [] + training_models.append(controlnet) + name_and_params = list(controlnet.named_parameters()) + # single param group for now + params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate}) + param_names = [[n for n, _ in name_and_params]] + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.blockwise_fused_optimizers: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. + # This balances memory usage and management complexity. + + # split params into groups. currently different learning rates are not supported + grouped_params = [] + param_group = {} + for group in params_to_optimize: + named_parameters = list(controlnet.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # double, single or other + if np[0].startswith("double_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "double" + elif np[0].startswith("single_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "single" + else: + block_index = -1 + + param_group_key = (block_type, block_index) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.blockwise_fused_optimizers: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + flux.to(weight_dtype) + controlnet.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) # TODO check works with fp16 or not + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + flux.to(weight_dtype) + controlnet.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) + + # if we don't cache text encoder outputs, move them to device + if not args.cache_text_encoder_outputs: + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # accelerator does some magic + # if we doesn't swap blocks, we can move the model to device + controlnet = accelerator.prepare(controlnet) # , device_placement=[not is_swapping_blocks]) + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): + if parameter.requires_grad: + + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook + + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) + + elif args.blockwise_fused_optimizers: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if is_swapping_blocks: + flux.prepare_block_swap_before_forward() + + # For --sample_at_first + optimizer_eval_fn() + flux_train_utils.sample_images( + accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet + ) + optimizer_train_fn() + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) + + loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.blockwise_fused_optimizers: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + ) + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = ( + flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width) + .to(device=accelerator.device) + .to(weight_dtype) + ) + + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device, dtype=weight_dtype) + + # call model + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + with accelerator.autocast(): + block_samples, block_single_samples = controlnet( + img=packed_noisy_model_input, + img_ids=img_ids, + controlnet_cond=batch["conditioning_images"].to(accelerator.device).to(weight_dtype), + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = flux( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + loss = train_util.conditional_loss( + model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None + ) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + optimizer_eval_fn() + flux_train_utils.sample_images( + accelerator, + args, + None, + global_step, + flux, + ae, + [clip_l, t5xxl], + sample_prompts_te_outputs, + controlnet=controlnet, + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(controlnet), + ) + optimizer_train_fn() + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if len(accelerator.trackers) > 0: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + optimizer_eval_fn() + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(controlnet), + ) + + flux_train_utils.sample_images( + accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet + ) + optimizer_train_fn() + + is_main_process = accelerator.is_main_process + # if is_main_process: + controlnet = accelerator.unwrap_model(controlnet) + + accelerator.end_training() + optimizer_eval_fn() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, controlnet) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) # TODO split this + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) # TODO remove this from here + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + + parser.add_argument( + "--mem_eff_save", + action="store_true", + help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う", + ) + + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます", + ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", + ) + parser.add_argument( + "--double_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) + parser.add_argument( + "--single_blocks_to_swap", + type=int, + default=None, + help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください", + ) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/flux_train_network.py b/flux_train_network.py new file mode 100644 index 00000000..75e975ba --- /dev/null +++ b/flux_train_network.py @@ -0,0 +1,583 @@ +import argparse +import copy +import math +import random +from typing import Any, Optional + +import torch +from accelerate import Accelerator + +from library.device_utils import clean_memory_on_device, init_ipex + +init_ipex() + +import train_network +from library import ( + flux_models, + flux_train_utils, + flux_utils, + sd3_train_utils, + strategy_base, + strategy_flux, + train_util, +) +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class FluxNetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + self.is_schnell: Optional[bool] = None + self.is_swapping_blocks: bool = False + + def assert_extra_args(self, args, train_dataset_group): + super().assert_extra_args(args, train_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # prepare CLIP-L/T5XXL training flags + self.train_clip_l = not args.network_train_unet_only + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + + if args.max_token_length is not None: + logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + + # deprecated split_mode option + if args.split_mode: + if args.blocks_to_swap is not None: + logger.warning( + "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored." + " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。" + ) + else: + logger.warning( + "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set." + " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。" + ) + args.blocks_to_swap = 18 # 18 is safe for most cases + + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + + def load_target_model(self, args, weight_dtype, accelerator): + # currently offload to cpu for some models + + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future + self.is_schnell, model = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + ) + if args.fp8_base: + # check dtype of model + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") + elif model.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 FLUX model") + else: + logger.info( + "Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" + ) + model.to(torch.float8_e4m3fn) + + # if args.split_mode: + # model = self.prepare_split_model(model, weight_dtype, accelerator) + + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + model.enable_block_swap(args.blocks_to_swap, accelerator.device) + + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + clip_l.eval() + + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + + # loading t5xxl to cpu takes a long time, so we should load to gpu in future + t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + + def get_tokenize_strategy(self, args): + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") + return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy): + return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + if args.cache_text_encoder_outputs: + if self.train_clip_l and not self.train_t5xxl: + return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached + else: + return None # no text encoders are needed for encoding because both are cached + else: + return text_encoders # both CLIP-L and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_clip_l, self.train_t5xxl] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_clip_l or self.train_t5xxl, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device) + + if text_encoders[1].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[1].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # move back to cpu + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[1].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device) + + # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # # get size embeddings + # orig_size = batch["original_sizes_hw"] + # crop_size = batch["crop_top_lefts"] + # target_size = batch["target_sizes_hw"] + # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # # concat embeddings + # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + # return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): + text_encoders = text_encoder # for compatibility + text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs + ) + # return + + """ + class FluxUpperLowerWrapper(torch.nn.Module): + def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device): + super().__init__() + self.flux_upper = flux_upper + self.flux_lower = flux_lower + self.target_device = device + + def prepare_block_swap_before_forward(self): + pass + + def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): + self.flux_lower.to("cpu") + clean_memory_on_device(self.target_device) + self.flux_upper.to(self.target_device) + img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask) + self.flux_upper.to("cpu") + clean_memory_on_device(self.target_device) + self.flux_lower.to(self.target_device) + return self.flux_lower(img, txt, vec, pe, txt_attention_mask) + + wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) + clean_memory_on_device(accelerator.device) + flux_train_utils.sample_images( + accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs + ) + clean_memory_on_device(accelerator.device) + """ + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images) + + def shift_scale_latents(self, args, latents): + return latents + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: flux_models.Flux, + network, + weight_dtype, + train_unet, + ): + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance + # ensure guidance_scale in args is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t is not None and t.dtype.is_floating_point: + t.requires_grad_(True) + img_ids.requires_grad_(True) + guidance_vec.requires_grad_(True) + + # Predict the noise residual + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): + # if not args.split_mode: + # normal forward + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=img, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + """ + else: + # split forward to reduce memory usage + assert network.train_blocks == "single", "train_blocks must be single for split mode" + with accelerator.autocast(): + # move flux lower to cpu, and then move flux upper to gpu + unet.to("cpu") + clean_memory_on_device(accelerator.device) + self.flux_upper.to(accelerator.device) + + # upper model does not require grad + with torch.no_grad(): + intermediate_img, intermediate_txt, vec, pe = self.flux_upper( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # move flux upper back to cpu, and then move flux lower to gpu + self.flux_upper.to("cpu") + clean_memory_on_device(accelerator.device) + unet.to(accelerator.device) + + # lower model requires grad + intermediate_img.requires_grad_(True) + intermediate_txt.requires_grad_(True) + vec.requires_grad_(True) + pe.requires_grad_(True) + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + """ + + return model_pred + + model_pred = call_dit( + img=packed_noisy_model_input, + img_ids=img_ids, + t5_out=t5_out, + txt_ids=txt_ids, + l_pooled=l_pooled, + timesteps=timesteps, + guidance_vec=guidance_vec, + t5_attn_mask=t5_attn_mask, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + unet.prepare_block_swap_before_forward() + with torch.no_grad(): + model_pred_prior = call_dit( + img=packed_noisy_model_input[diff_output_pr_indices], + img_ids=img_ids[diff_output_pr_indices], + t5_out=t5_out[diff_output_pr_indices], + txt_ids=txt_ids[diff_output_pr_indices], + l_pooled=l_pooled[diff_output_pr_indices], + timesteps=timesteps[diff_output_pr_indices], + guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None, + t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + + model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width) + model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( + args, + model_pred_prior, + noisy_model_input[diff_output_pr_indices], + sigmas[diff_output_pr_indices] if sigmas is not None else None, + ) + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + + return model_pred, target, timesteps, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") + + def update_metadata(self, metadata, args): + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + metadata["ss_guidance_scale"] = args.guidance_scale + metadata["ss_timestep_sampling"] = args.timestep_sampling + metadata["ss_sigmoid_scale"] = args.sigmoid_scale + metadata["ss_model_prediction_type"] = args.model_prediction_type + metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + if index == 0: # CLIP-L + return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) + else: # T5XXL + text_encoder.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + if index == 0: # CLIP-L + logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + else: # T5XXL + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"T5XXL already prepared for fp8") + else: + logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + prepare_fp8(text_encoder, weight_dtype) + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + flux: flux_models.Flux = unet + flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() + + return flux + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + + parser.add_argument( + "--split_mode", + action="store_true", + # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required" + # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要", + help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead." + " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = FluxNetworkTrainer() + trainer.train(args) diff --git a/gen_img.py b/gen_img.py index 4fe89871..d0c99bd1 100644 --- a/gen_img.py +++ b/gen_img.py @@ -43,8 +43,8 @@ from diffusers import ( ) from einops import rearrange from tqdm import tqdm -from torchvision import transforms from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor +from accelerate import init_empty_weights import PIL from PIL import Image from PIL.PngImagePlugin import PngInfo @@ -58,6 +58,7 @@ import tools.original_control_net as original_control_net from tools.original_control_net import ControlNetInfo from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel from library.sdxl_original_unet import InferSdxlUNet2DConditionModel +from library.sdxl_original_control_net import SdxlControlNet from library.original_unet import FlashAttentionFunction from networks.control_net_lllite import ControlNetLLLite from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL @@ -86,7 +87,8 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" """ -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): +# def replace_unet_modules(unet: diffusers.models.unets.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): +def replace_unet_modules(unet, mem_eff_attn, xformers, sdpa): if mem_eff_attn: logger.info("Enable memory efficient attention for U-Net") @@ -351,8 +353,8 @@ class PipelineLike: self.token_replacements_list.append({}) # ControlNet - self.control_nets: List[ControlNetInfo] = [] # only for SD 1.5 - self.control_net_lllites: List[ControlNetLLLite] = [] + self.control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = [] + self.control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない self.gradual_latent: GradualLatent = None @@ -541,7 +543,7 @@ class PipelineLike: else: text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) - if self.control_net_lllites: + if self.control_net_lllites or (self.control_nets and self.is_sdxl): # ControlNetのhintにguide imageを流用する。ControlNetの場合はControlNet側で行う if isinstance(clip_guide_images, PIL.Image.Image): clip_guide_images = [clip_guide_images] @@ -730,7 +732,12 @@ class PipelineLike: num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 if self.control_nets: - guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + if not self.is_sdxl: + guided_hints = original_control_net.get_guided_hints( + self.control_nets, num_latent_input, batch_size, clip_guide_images + ) + else: + clip_guide_images = clip_guide_images * 0.5 + 0.5 # [-1, 1] => [0, 1] each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) if self.control_net_lllites: @@ -792,7 +799,7 @@ class PipelineLike: latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # disable ControlNet-LLLite if ratio is set. ControlNet is disabled in ControlNetInfo + # disable ControlNet-LLLite or SDXL ControlNet if ratio is set. ControlNet is disabled in ControlNetInfo if self.control_net_lllites: for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_net_lllites, each_control_net_enabled)): if not enabled or ratio >= 1.0: @@ -801,9 +808,16 @@ class PipelineLike: logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") control_net.set_cond_image(None) each_control_net_enabled[j] = False + if self.control_nets and self.is_sdxl: + for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)): + if not enabled or ratio >= 1.0: + continue + if ratio < i / len(timesteps): + logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + each_control_net_enabled[j] = False # predict the noise residual - if self.control_nets and self.control_net_enabled: + if self.control_nets and self.control_net_enabled and not self.is_sdxl: if regional_network: num_sub_and_neg_prompts = len(text_embeddings) // batch_size text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt @@ -822,6 +836,31 @@ class PipelineLike: text_embeddings, text_emb_last, ).sample + elif self.control_nets: + input_resi_add_list = [] + mid_add_list = [] + for (control_net, _), enbld in zip(self.control_nets, each_control_net_enabled): + if not enbld: + continue + input_resi_add, mid_add = control_net( + latent_model_input, t, text_embeddings, vector_embeddings, clip_guide_images + ) + input_resi_add_list.append(input_resi_add) + mid_add_list.append(mid_add) + if len(input_resi_add_list) == 0: + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) + else: + if len(input_resi_add_list) > 1: + # get mean of input_resi_add_list and mid_add_list + input_resi_add_mean = [] + for i in range(len(input_resi_add_list[0])): + input_resi_add_mean.append( + torch.mean(torch.stack([input_resi_add_list[j][i] for j in range(len(input_resi_add_list))], dim=0)) + ) + input_resi_add = input_resi_add_mean + mid_add = torch.mean(torch.stack(mid_add_list), dim=0) + + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings, input_resi_add, mid_add) elif self.is_sdxl: noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) else: @@ -1435,6 +1474,7 @@ class BatchDataBase(NamedTuple): clip_prompt: str guide_image: Any raw_prompt: str + file_name: Optional[str] class BatchDataExt(NamedTuple): @@ -1493,8 +1533,6 @@ def main(args): highres_fix = args.highres_fix_scale is not None # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" - if args.v_parameterization and not args.v2: - logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") @@ -1825,16 +1863,37 @@ def main(args): upscaler.to(dtype).to(device) # ControlNetの処理 - control_nets: List[ControlNetInfo] = [] + control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = [] if args.control_net_models: - for i, model in enumerate(args.control_net_models): - prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] - weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] - ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + if not is_sdxl: + for i, model in enumerate(args.control_net_models): + prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] - ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) - prep = original_control_net.load_preprocess(prep_type) - control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) + prep = original_control_net.load_preprocess(prep_type) + control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + else: + for i, model_file in enumerate(args.control_net_models): + multiplier = ( + 1.0 + if not args.control_net_multipliers or len(args.control_net_multipliers) <= i + else args.control_net_multipliers[i] + ) + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + logger.info(f"loading SDXL ControlNet: {model_file}") + from safetensors.torch import load_file + + state_dict = load_file(model_file) + + logger.info(f"Initializing SDXL ControlNet with multiplier: {multiplier}") + with init_empty_weights(): + control_net = SdxlControlNet(multiplier=multiplier) + control_net.load_state_dict(state_dict) + control_net.to(dtype).to(device) + control_nets.append((control_net, ratio)) control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] if args.control_net_lllite_models: @@ -2316,7 +2375,7 @@ def main(args): # このバッチの情報を取り出す ( return_latents, - (step_first, _, _, _, init_image, mask_image, _, guide_image, _), + (step_first, _, _, _, init_image, mask_image, _, guide_image, _, _), ( width, height, @@ -2339,6 +2398,7 @@ def main(args): prompts = [] negative_prompts = [] raw_prompts = [] + filenames = [] start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) noises = [ torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) @@ -2371,7 +2431,7 @@ def main(args): all_guide_images_are_same = True for i, ( _, - (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt, filename), _, ) in enumerate(batch): prompts.append(prompt) @@ -2379,6 +2439,7 @@ def main(args): seeds.append(seed) clip_prompts.append(clip_prompt) raw_prompts.append(raw_prompt) + filenames.append(filename) if init_image is not None: init_images.append(init_image) @@ -2478,8 +2539,8 @@ def main(args): # save image highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) + for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt, filename) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts, filenames) ): if highres_fix: seed -= 1 # record original seed @@ -2505,17 +2566,23 @@ def main(args): metadata.add_text("crop-top", str(crop_top)) metadata.add_text("crop-left", str(crop_left)) - if args.use_original_file_name and init_images is not None: - if type(init_images) is list: - fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" - else: - fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" - elif args.sequential_file_name: - fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + if filename is not None: + fln = filename else: - fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + else: + fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" - image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + if fln.endswith(".webp"): + image.save(os.path.join(args.outdir, fln), pnginfo=metadata, quality=100) # lossy + else: + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) if not args.no_preview and not highres_1st and args.interactive: try: @@ -2562,6 +2629,7 @@ def main(args): # repeat prompt for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] + filename = None if pi == 0 or len(raw_prompts) > 1: # parse prompt: if prompt is not changed, skip parsing @@ -2783,6 +2851,12 @@ def main(args): logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") continue + m = re.match(r"f (.+)", parg, re.IGNORECASE) + if m: # filename + filename = m.group(1) + logger.info(f"filename: {filename}") + continue + except ValueError as ex: logger.error(f"Exception in parsing / 解析エラー: {parg}") logger.error(f"{ex}") @@ -2873,7 +2947,16 @@ def main(args): b1 = BatchData( False, BatchDataBase( - global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + global_step, + prompt, + negative_prompt, + seed, + init_image, + mask_image, + clip_prompt, + guide_image, + raw_prompt, + filename, ), BatchDataExt( width, @@ -2916,7 +2999,7 @@ def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() add_logging_arguments(parser) - + parser.add_argument( "--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む" ) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 2c40f1a0..04db4e9b 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2216,8 +2216,6 @@ def main(args): highres_fix = args.highres_fix_scale is not None # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" - if args.v_parameterization and not args.v2: - logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py new file mode 100644 index 00000000..b5afa236 --- /dev/null +++ b/library/adafactor_fused.py @@ -0,0 +1,138 @@ +import math +import torch +from transformers import Adafactor + +# stochastic rounding for bfloat16 +# The implementation was provided by 2kpr. Thank you very much! + +def copy_stochastic_(target: torch.Tensor, source: torch.Tensor): + """ + copies source into target using stochastic rounding + + Args: + target: the target tensor with dtype=bfloat16 + source: the target tensor with dtype=float32 + """ + # create a random 16 bit integer + result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16)) + + # add the random number to the lower 16 bit of the mantissa + result.add_(source.view(dtype=torch.int32)) + + # mask off the lower 16 bit of the mantissa + result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 + + # copy the higher 16 bit into the target tensor + target.copy_(result.view(dtype=torch.float32)) + + del result + + +@torch.no_grad() +def adafactor_step_param(self, p, group): + if p.grad is None: + return + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = Adafactor._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + state["RMS"] = Adafactor._rms(p_data_fp32) + lr = Adafactor._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + + # Approximation of exponential moving average of square of gradient + update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) + + p_data_fp32.add_(-update) + + # if p.dtype in {torch.float16, torch.bfloat16}: + # p.copy_(p_data_fp32) + + if p.dtype == torch.bfloat16: + copy_stochastic_(p, p_data_fp32) + elif p.dtype == torch.float16: + p.copy_(p_data_fp32) + + +@torch.no_grad() +def adafactor_step(self, closure=None): + """ + Performs a single optimization step + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + adafactor_step_param(self, p, group) + + return loss + + +def patch_adafactor_fused(optimizer: Adafactor): + optimizer.step_param = adafactor_step_param.__get__(optimizer) + optimizer.step = adafactor_step.__get__(optimizer) diff --git a/library/config_util.py b/library/config_util.py index 2f627259..a09d2c7c 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -10,13 +10,7 @@ import json from pathlib import Path # from toolz import curry -from typing import ( - List, - Optional, - Sequence, - Tuple, - Union, -) +from typing import Dict, List, Optional, Sequence, Tuple, Union import toml import voluptuous @@ -78,6 +72,9 @@ class BaseSubsetParams: caption_tag_dropout_rate: float = 0.0 token_warmup_min: int = 1 token_warmup_step: float = 0 + custom_attributes: Optional[Dict[str, Any]] = None + validation_seed: int = 0 + validation_split: float = 0.0 @dataclass @@ -86,11 +83,13 @@ class DreamBoothSubsetParams(BaseSubsetParams): class_tokens: Optional[str] = None caption_extension: str = ".caption" cache_info: bool = False + alpha_mask: bool = False @dataclass class FineTuningSubsetParams(BaseSubsetParams): metadata_file: Optional[str] = None + alpha_mask: bool = False @dataclass @@ -102,14 +101,13 @@ class ControlNetSubsetParams(BaseSubsetParams): @dataclass class BaseDatasetParams: - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None - max_token_length: int = None resolution: Optional[Tuple[int, int]] = None network_multiplier: float = 1.0 debug_dataset: bool = False validation_seed: Optional[int] = None validation_split: float = 0.0 + @dataclass class DreamBoothDatasetParams(BaseDatasetParams): batch_size: int = 1 @@ -191,11 +189,13 @@ class ConfigSanitizer: "keep_tokens": int, "keep_tokens_separator": str, "secondary_separator": str, + "caption_separator": str, "enable_wildcard": bool, "token_warmup_min": int, "token_warmup_step": Any(float, int), "caption_prefix": str, "caption_suffix": str, + "custom_attributes": dict, } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { @@ -212,11 +212,13 @@ class ConfigSanitizer: DB_SUBSET_DISTINCT_SCHEMA = { Required("image_dir"): str, "is_reg": bool, + "alpha_mask": bool, } # FT means FineTuning FT_SUBSET_DISTINCT_SCHEMA = { Required("metadata_file"): str, "image_dir": str, + "alpha_mask": bool, } CN_SUBSET_ASCENDABLE_SCHEMA = { "caption_extension": str, @@ -480,7 +482,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) datasets.append(dataset) val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] @@ -488,17 +490,17 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu if dataset_blueprint.params.validation_split <= 0.0: continue if dataset_blueprint.is_controlnet: - subset_klass = ControlNetSubset - dataset_klass = ControlNetDataset + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset elif dataset_blueprint.is_dreambooth: - subset_klass = DreamBoothSubset - dataset_klass = DreamBoothDataset + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset else: - subset_klass = FineTuningSubset - dataset_klass = FineTuningDataset - + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) # print info @@ -543,6 +545,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu random_crop: {subset.random_crop} token_warmup_min: {subset.token_warmup_min}, token_warmup_step: {subset.token_warmup_step}, + alpha_mask: {subset.alpha_mask} + custom_attributes: {subset.custom_attributes} """), " ") if is_dreambooth: @@ -564,6 +568,50 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu print("Validation dataset") print_info(val_datasets) + if len(val_datasets) > 0: + info = "" + + for i, dataset in enumerate(val_datasets): + info += dedent( + f"""\ + [Validation Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + network_multiplier: {dataset.network_multiplier} + """ + ) + + if dataset.enable_bucket: + info += indent( + dedent( + f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n""" + ), + " ", + ) + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): + info += indent( + dedent( + f"""\ + [Subset {j} of Validation Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + """ + ), + " ", + ) + + logger.info(f"{info}") + # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no @@ -574,7 +622,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset.set_seed(seed) for i, dataset in enumerate(val_datasets): - print(f"[Validation Dataset {i}]") + logger.info(f"[Validation Dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py new file mode 100644 index 00000000..84c2b743 --- /dev/null +++ b/library/custom_offloading_utils.py @@ -0,0 +1,227 @@ +from concurrent.futures import ThreadPoolExecutor +import time +from typing import Optional +import torch +import torch.nn as nn + +from library.device_utils import clean_memory_on_device + + +def synchronize_device(device: torch.device): + if device.type == "cuda": + torch.cuda.synchronize() + elif device.type == "xpu": + torch.xpu.synchronize() + elif device.type == "mps": + torch.mps.synchronize() + + +def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + + # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules + # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + # print(module_to_cpu.__class__, module_to_cuda.__class__) + # if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + # weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()} + for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules(): + if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None: + module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None) + if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + else: + if module_to_cuda.weight.data.device.type != device.type: + # print( + # f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device" + # ) + module_to_cuda.weight.data = module_to_cuda.weight.data.to(device) + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + # cuda to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.record_stream(stream) + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + stream.synchronize() + + # cpu to cuda + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + stream.synchronize() + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + """ + not tested + """ + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + # device to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + synchronize_device() + + # cpu to device + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + synchronize_device() + + +def weighs_to_device(layer: nn.Module, device: torch.device): + for module in layer.modules(): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data = module.weight.data.to(device, non_blocking=True) + + +class Offloader: + """ + common offloading class + """ + + def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + self.num_blocks = num_blocks + self.blocks_to_swap = blocks_to_swap + self.device = device + self.debug = debug + + self.thread_pool = ThreadPoolExecutor(max_workers=1) + self.futures = {} + self.cuda_available = device.type == "cuda" + + def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module): + if self.cuda_available: + swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda) + else: + swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda) + + def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): + if self.debug: + start_time = time.perf_counter() + print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}") + + self.swap_weight_devices(block_to_cpu, block_to_cuda) + + if self.debug: + print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s") + return bidx_to_cpu, bidx_to_cuda # , event + + block_to_cpu = blocks[block_idx_to_cpu] + block_to_cuda = blocks[block_idx_to_cuda] + + self.futures[block_idx_to_cuda] = self.thread_pool.submit( + move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda + ) + + def _wait_blocks_move(self, block_idx): + if block_idx not in self.futures: + return + + if self.debug: + print(f"Wait for block {block_idx}") + start_time = time.perf_counter() + + future = self.futures.pop(block_idx) + _, bidx_to_cuda = future.result() + + assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}" + + if self.debug: + print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") + + +class ModelOffloader(Offloader): + """ + supports forward offloading + """ + + def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + super().__init__(num_blocks, blocks_to_swap, device, debug) + + # register backward hooks + self.remove_handles = [] + for i, block in enumerate(blocks): + hook = self.create_backward_hook(blocks, i) + if hook is not None: + handle = block.register_full_backward_hook(hook) + self.remove_handles.append(handle) + + def __del__(self): + for handle in self.remove_handles: + handle.remove() + + def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]: + # -1 for 0-based index + num_blocks_propagated = self.num_blocks - block_index - 1 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap + waiting = block_index > 0 and block_index <= self.blocks_to_swap + + if not swapping and not waiting: + return None + + # create hook + block_idx_to_cpu = self.num_blocks - num_blocks_propagated + block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated + block_idx_to_wait = block_index - 1 + + def backward_hook(module, grad_input, grad_output): + if self.debug: + print(f"Backward hook for block {block_index}") + + if swapping: + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) + if waiting: + self._wait_blocks_move(block_idx_to_wait) + return None + + return backward_hook + + def prepare_block_devices_before_forward(self, blocks: list[nn.Module]): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + + if self.debug: + print("Prepare block devices before forward") + + for b in blocks[0 : self.num_blocks - self.blocks_to_swap]: + b.to(self.device) + weighs_to_device(b, self.device) # make sure weights are on device + + for b in blocks[self.num_blocks - self.blocks_to_swap :]: + b.to(self.device) # move block to device first + weighs_to_device(b, "cpu") # make sure weights are on cpu + + synchronize_device(self.device) + clean_memory_on_device(self.device) + + def wait_for_block(self, block_idx: int): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self._wait_blocks_move(block_idx) + + def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + if block_idx >= self.blocks_to_swap: + return + block_idx_to_cpu = block_idx + block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index e7c17c98..9a7c21a3 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -98,10 +98,13 @@ def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, n return loss -def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler): +def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False): snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 - weight = 1 / torch.sqrt(snr_t) + if v_prediction: + weight = 1 / (snr_t + 1) + else: + weight = 1 / torch.sqrt(snr_t) loss = weight * loss return loss @@ -482,12 +485,20 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): def apply_masked_loss(loss, batch): - # mask image is -1 to 1. we need to convert it to 0 to 1 - mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel + if "conditioning_images" in batch: + # conditioning image is -1 to 1. we need to convert it to 0 to 1 + mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel + mask_image = mask_image / 2 + 0.5 + # print(f"conditioning_image: {mask_image.shape}") + elif "alpha_masks" in batch and batch["alpha_masks"] is not None: + # alpha mask is 0 to 1 + mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension + # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}") + else: + return loss # resize to the same size as the loss mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area") - mask_image = mask_image / 2 + 0.5 loss = loss * mask_image return loss diff --git a/library/flux_models.py b/library/flux_models.py new file mode 100644 index 00000000..328ad481 --- /dev/null +++ b/library/flux_models.py @@ -0,0 +1,1493 @@ +# copy from FLUX repo: https://github.com/black-forest-labs/flux +# license: Apache-2.0 License + + +import math +import os +import time +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + +from library import utils +from library.device_utils import clean_memory_on_device, init_ipex + +init_ipex() + +import torch +from einops import rearrange +from torch import Tensor, nn +from torch.utils.checkpoint import checkpoint + +from library import custom_offloading_utils + +# USE_REENTRANT = True + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +# region autoencoder + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) + + +# endregion +# region config + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: str | None + ae_path: str | None + # repo_id: str | None + # repo_flow: str | None + # repo_ae: str | None + + +configs = { + "dev": ModelSpec( + # repo_id="black-forest-labs/FLUX.1-dev", + # repo_flow="flux1-dev.sft", + # repo_ae="ae.sft", + ckpt_path=None, # os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=None, # os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "schnell": ModelSpec( + # repo_id="black-forest-labs/FLUX.1-schnell", + # repo_flow="flux1-schnell.sft", + # repo_ae="ae.sft", + ckpt_path=None, # os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=None, # os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +# endregion + +# region math + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor: + q, k = apply_rope(q, k, pe) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + +# endregion + + +# region layers + + +# for cpu_offload_checkpointing + + +def to_cuda(x): + if isinstance(x, torch.Tensor): + return x.cuda() + elif isinstance(x, (list, tuple)): + return [to_cuda(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cuda(v) for k, v in x.items()} + else: + return x + + +def to_cpu(x): + if isinstance(x, torch.Tensor): + return x.cpu() + elif isinstance(x, (list, tuple)): + return [to_cpu(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cpu(v) for k, v in x.items()} + else: + return x + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + # def forward(self, x): + # if self.training and self.gradient_checkpointing: + # def create_custom_forward(func): + # def custom_forward(*inputs): + # return func(*inputs) + # return custom_forward + # return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, use_reentrant=USE_REENTRANT) + # else: + # return self._forward(x) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + # return (x * rrms).to(dtype=x_dtype) * self.scale + return ((x * rrms) * self.scale.float()).to(dtype=x_dtype) + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + # this is not called from DoubleStreamBlock/SingleStreamBlock because they uses attention function directly + def forward(self, x: Tensor, pe: Tensor) -> Tensor: + qkv = self.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = self.proj(x) + return x + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def _forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + # make attention mask if not None + attn_mask = None + if txt_attention_mask is not None: + # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask + attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len + attn_mask = torch.cat( + (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1 + ) # b, seq_len + img_len + + # broadcast attn_mask to all heads + attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1) + + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img blocks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + + # calculate the txt blocks + txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + return img, txt + + def forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: + if self.training and self.gradient_checkpointing: + if not self.cpu_offload_checkpointing: + return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False) + # cpu offload checkpointing + + def create_custom_forward(func): + def custom_forward(*inputs): + cuda_inputs = to_cuda(inputs) + outputs = func(*cuda_inputs) + return to_cpu(outputs) + + return custom_forward + + return torch.utils.checkpoint.checkpoint( + create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False + ) + + else: + return self._forward(img, txt, vec, pe, txt_attention_mask) + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + # make attention mask if not None + attn_mask = None + if txt_attention_mask is not None: + # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask + attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len + attn_mask = torch.cat( + ( + attn_mask, + torch.ones( + attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool + ), + ), + dim=1, + ) # b, seq_len + img_len = x_len + + # broadcast attn_mask to all heads + attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1) + + # compute attention + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) + + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + return x + mod.gate * output + + def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: + if self.training and self.gradient_checkpointing: + if not self.cpu_offload_checkpointing: + return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False) + + # cpu offload checkpointing + + def create_custom_forward(func): + def custom_forward(*inputs): + cuda_inputs = to_cuda(inputs) + outputs = func(*cuda_inputs) + return to_cpu(outputs) + + return custom_forward + + return torch.utils.checkpoint.checkpoint( + create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False + ) + else: + return self._forward(x, vec, pe, txt_attention_mask) + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x + + +# endregion + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + self.blocks_to_swap = None + + self.offloader_double = None + self.offloader_single = None + self.num_double_blocks = len(self.double_blocks) + self.num_single_blocks = len(self.single_blocks) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + self.time_in.enable_gradient_checkpointing() + self.vector_in.enable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.enable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) + + print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + self.time_in.disable_gradient_checkpointing() + self.vector_in.disable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.disable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def enable_block_swap(self, num_blocks: int, device: torch.device): + self.blocks_to_swap = num_blocks + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + + assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( + f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. " + f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." + ) + + self.offloader_double = custom_offloading_utils.ModelOffloader( + self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True + ) + self.offloader_single = custom_offloading_utils.ModelOffloader( + self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True + ) + print( + f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." + ) + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_double_blocks = self.double_blocks + save_single_blocks = self.single_blocks + self.double_blocks = None + self.single_blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.double_blocks = save_double_blocks + self.single_blocks = save_single_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + block_controlnet_hidden_states=None, + block_controlnet_single_hidden_states=None, + guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + if block_controlnet_hidden_states is not None: + controlnet_depth = len(block_controlnet_hidden_states) + if block_controlnet_single_hidden_states is not None: + controlnet_single_depth = len(block_controlnet_single_hidden_states) + + if not self.blocks_to_swap: + for block_idx, block in enumerate(self.double_blocks): + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_hidden_states is not None and controlnet_depth > 0: + img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] + + img = torch.cat((txt, img), 1) + for block_idx, block in enumerate(self.single_blocks): + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0: + img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] + else: + for block_idx, block in enumerate(self.double_blocks): + self.offloader_double.wait_for_block(block_idx) + + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_hidden_states is not None and controlnet_depth > 0: + img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] + + self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) + + img = torch.cat((txt, img), 1) + + for block_idx, block in enumerate(self.single_blocks): + self.offloader_single.wait_for_block(block_idx) + + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0: + img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] + + self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) + + img = img[:, txt.shape[1] :, ...] + + if self.training and self.cpu_offload_checkpointing: + img = img.to(self.device) + vec = vec.to(self.device) + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + + return img + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNetFlux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams, controlnet_depth=2, controlnet_single_depth=0): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(controlnet_depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(controlnet_single_depth) + ] + ) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + self.blocks_to_swap = None + + self.offloader_double = None + self.offloader_single = None + self.num_double_blocks = len(self.double_blocks) + self.num_single_blocks = len(self.single_blocks) + + # add ControlNet blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(controlnet_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks.append(controlnet_block) + self.controlnet_blocks_for_single = nn.ModuleList([]) + for _ in range(controlnet_single_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks_for_single.append(controlnet_block) + self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.gradient_checkpointing = False + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(nn.Conv2d(16, 16, 3, padding=1)) + ) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + self.time_in.enable_gradient_checkpointing() + self.vector_in.enable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.enable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) + + print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + self.time_in.disable_gradient_checkpointing() + self.vector_in.disable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.disable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def enable_block_swap(self, num_blocks: int, device: torch.device): + self.blocks_to_swap = num_blocks + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + + assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( + f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. " + f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." + ) + + self.offloader_double = custom_offloading_utils.ModelOffloader( + self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True + ) + self.offloader_single = custom_offloading_utils.ModelOffloader( + self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True + ) + print( + f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." + ) + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_double_blocks = self.double_blocks + save_single_blocks = self.single_blocks + self.double_blocks = None + self.single_blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.double_blocks = save_double_blocks + self.single_blocks = save_single_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + controlnet_cond: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, + ) -> tuple[tuple[Tensor]]: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + block_samples = () + block_single_samples = () + if not self.blocks_to_swap: + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_samples = block_samples + (img,) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_single_samples = block_single_samples + (img,) + else: + for block_idx, block in enumerate(self.double_blocks): + self.offloader_double.wait_for_block(block_idx) + + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_samples = block_samples + (img,) + + self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) + + img = torch.cat((txt, img), 1) + + for block_idx, block in enumerate(self.single_blocks): + self.offloader_single.wait_for_block(block_idx) + + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + block_single_samples = block_single_samples + (img,) + + self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) + + controlnet_block_samples = () + controlnet_single_block_samples = () + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): + block_sample = controlnet_block(block_sample) + controlnet_block_samples = controlnet_block_samples + (block_sample,) + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_single): + block_sample = controlnet_block(block_sample) + controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,) + + return controlnet_block_samples, controlnet_single_block_samples + + +""" +class FluxUpper(nn.Module): + "" + Transformer model for flow matching on sequences. + "" + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.gradient_checkpointing = False + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + self.time_in.enable_gradient_checkpointing() + self.vector_in.enable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.enable_gradient_checkpointing() + + for block in self.double_blocks: + block.enable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing enabled.") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + self.time_in.disable_gradient_checkpointing() + self.vector_in.disable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.disable_gradient_checkpointing() + + for block in self.double_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + + return img, txt, vec, pe + + +class FluxLower(nn.Module): + "" + Transformer model for flow matching on sequences. + "" + + def __init__(self, params: FluxParams): + super().__init__() + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.out_channels = params.in_channels + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + self.gradient_checkpointing = False + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + for block in self.single_blocks: + block.enable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing enabled.") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + for block in self.single_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def forward( + self, + img: Tensor, + txt: Tensor, + vec: Tensor | None = None, + pe: Tensor | None = None, + txt_attention_mask: Tensor | None = None, + ) -> Tensor: + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img +""" diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py new file mode 100644 index 00000000..f7f06c5c --- /dev/null +++ b/library/flux_train_utils.py @@ -0,0 +1,619 @@ +import argparse +import math +import os +import numpy as np +import toml +import json +import time +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from accelerate import Accelerator, PartialState +from transformers import CLIPTextModel +from tqdm import tqdm +from PIL import Image +from safetensors.torch import save_file + +from library import flux_models, flux_utils, strategy_base, train_util +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from .utils import setup_logging, mem_eff_save_file + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +# region sample images + + +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + flux, + ae, + text_encoders, + sample_prompts_te_outputs, + prompt_replacement=None, + controlnet=None +): + if steps == 0: + if not args.sample_at_first: + return + else: + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return + + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None: + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return + + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap unet and text_encoder(s) + flux = accelerator.unwrap_model(flux) + if text_encoders is not None: + text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + if controlnet is not None: + controlnet = accelerator.unwrap_model(controlnet) + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = train_util.load_prompts(args.sample_prompts) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(), accelerator.autocast(): + for prompt_dict in prompts: + sample_image_inference( + accelerator, + args, + flux, + text_encoders, + ae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + controlnet + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, + args, + flux, + text_encoders, + ae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + controlnet + ) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + clean_memory_on_device(accelerator.device) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + flux: flux_models.Flux, + text_encoders: Optional[List[CLIPTextModel]], + ae: flux_models.AutoEncoder, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + controlnet +): + assert isinstance(prompt_dict, dict) + # negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 20) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 3.5) + seed = prompt_dict.get("seed") + controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + # if negative_prompt is not None: + # negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + # if negative_prompt is None: + # negative_prompt = "" + height = max(64, height - height % 16) # round to divisible by 16 + width = max(64, width - width % 16) # round to divisible by 16 + logger.info(f"prompt: {prompt}") + # logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + # logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + text_encoder_conds = [] + if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: + text_encoder_conds = sample_prompts_te_outputs[prompt] + print(f"Using cached text encoder outputs for prompt: {prompt}") + if text_encoders is not None: + print(f"Encoding prompt: {prompt}") + tokens_and_masks = tokenize_strategy.tokenize(prompt) + # strategy has apply_t5_attn_mask option + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + + # sample image + weight_dtype = ae.dtype # TOFO give dtype as argument + packed_latent_height = height // 16 + packed_latent_width = width // 16 + noise = torch.randn( + 1, + packed_latent_height * packed_latent_width, + 16 * 2 * 2, + device=accelerator.device, + dtype=weight_dtype, + generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None, + ) + timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True + img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) + t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None + + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) + + with accelerator.autocast(), torch.no_grad(): + x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) + + x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) + + # latent to image + clean_memory_on_device(accelerator.device) + org_vae_device = ae.device # will be on cpu + ae.to(accelerator.device) # distributed_state.device is same as accelerator.device + with accelerator.autocast(), torch.no_grad(): + x = ae.decode(x) + ae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + + x = x.clamp(-1, 1) + x = x.permute(0, 2, 3, 1) + image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: + wandb_tracker = accelerator.get_tracker("wandb") + + import wandb + + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption + + +def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: flux_models.Flux, + img: torch.Tensor, + img_ids: torch.Tensor, + txt: torch.Tensor, + txt_ids: torch.Tensor, + vec: torch.Tensor, + timesteps: list[float], + guidance: float = 4.0, + t5_attn_mask: Optional[torch.Tensor] = None, + controlnet: Optional[flux_models.ControlNetFlux] = None, + controlnet_img: Optional[torch.Tensor] = None, +): + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + + + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + model.prepare_block_swap_before_forward() + if controlnet is not None: + block_samples, block_single_samples = controlnet( + img=img, + img_ids=img_ids, + controlnet_cond=controlnet_img, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + else: + block_samples = None + block_single_samples = None + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + img = img + (t_prev - t_curr) * pred + + model.prepare_block_swap_before_forward() + return img + + +# endregion + + +# region train +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +def get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz, _, h, w = latents.shape + sigmas = None + + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": + # Simple random t-based noise sampling + if args.timestep_sampling == "sigmoid": + # https://github.com/XLabs-AI/x-flux/tree/main + t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + else: + t = torch.rand((bsz,), device=device) + + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "shift": + shift = args.discrete_flow_shift + logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "flux_shift": + logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) + timesteps = time_shift(mu, 1.0, timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise + else: + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps[indices].to(device=device) + + # Add noise according to flow matching. + sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas + + +def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): + weighting = None + if args.model_prediction_type == "raw": + pass + elif args.model_prediction_type == "additive": + # add the model_pred to the noisy_model_input + model_pred = model_pred + noisy_model_input + elif args.model_prediction_type == "sigma_scaled": + # apply sigma scaling + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + return model_pred, weighting + + +def save_models( + ckpt_path: str, + flux: flux_models.Flux, + sai_metadata: Optional[dict], + save_dtype: Optional[torch.dtype] = None, + use_mem_eff_save: bool = False, +): + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None and v.dtype != save_dtype: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + update_sd("", flux.state_dict()) + + if not use_mem_eff_save: + save_file(state_dict, ckpt_path, metadata=sai_metadata) + else: + mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata) + + +def save_flux_model_on_train_end( + args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") + save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save) + + train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_flux_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + flux: flux_models.Flux, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") + save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + True, + True, + epoch, + num_train_epochs, + global_step, + sd_saver, + None, + ) + + +# endregion + + +def add_flux_train_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--clip_l", + type=str, + help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument( + "--t5xxl", + type=str, + help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)" + ) + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=None, + help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev" + " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512", + ) + parser.add_argument( + "--apply_t5_attn_mask", + action="store_true", + help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--timestep_sampling", + choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"], + default="sigma", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。", + ) + parser.add_argument( + "--sigmoid_scale", + type=float, + default=1.0, + help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', + ) + parser.add_argument( + "--model_prediction_type", + choices=["raw", "additive", "sigma_scaled"], + default="sigma_scaled", + help="How to interpret and process the model prediction: " + "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." + " / モデル予測の解釈と処理方法:" + "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=3.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", + ) diff --git a/library/flux_utils.py b/library/flux_utils.py new file mode 100644 index 00000000..8be1d63e --- /dev/null +++ b/library/flux_utils.py @@ -0,0 +1,488 @@ +import json +import os +from dataclasses import replace +from typing import List, Optional, Tuple, Union + +import einops +import torch +from accelerate import init_empty_weights +from safetensors import safe_open +from safetensors.torch import load_file +from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import flux_models +from library.utils import load_safetensors + +MODEL_VERSION_FLUX_V1 = "flux1" +MODEL_NAME_DEV = "dev" +MODEL_NAME_SCHNELL = "schnell" + + +def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: + """ + チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。 + + Args: + ckpt_path (str): チェックポイントファイルまたはディレクトリのパス。 + + Returns: + Tuple[bool, bool, Tuple[int, int], List[str]]: + - bool: Diffusersかどうかを示すフラグ。 + - bool: Schnellかどうかを示すフラグ。 + - Tuple[int, int]: ダブルブロックとシングルブロックの数。 + - List[str]: チェックポイントに含まれるキーのリスト。 + """ + # check the state dict: Diffusers or BFL, dev or schnell, number of blocks + logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell") + + if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers + ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors") + if "00001-of-00003" in ckpt_path: + ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)] + else: + ckpt_paths = [ckpt_path] + + keys = [] + for ckpt_path in ckpt_paths: + with safe_open(ckpt_path, framework="pt") as f: + keys.extend(f.keys()) + + # if the key has annoying prefix, remove it + if keys[0].startswith("model.diffusion_model."): + keys = [key.replace("model.diffusion_model.", "") for key in keys] + + is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys + is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) + + # check number of double and single blocks + if not is_diffusers: + max_double_block_index = max( + [int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")] + ) + max_single_block_index = max( + [int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")] + ) + else: + max_double_block_index = max( + [ + int(key.split(".")[1]) + for key in keys + if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias") + ] + ) + max_single_block_index = max( + [ + int(key.split(".")[1]) + for key in keys + if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias") + ] + ) + + num_double_blocks = max_double_block_index + 1 + num_single_blocks = max_single_block_index + 1 + + return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths + + +def load_flow_model( + ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False +) -> Tuple[bool, flux_models.Flux]: + is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + + # build model + logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint") + with torch.device("meta"): + params = flux_models.configs[name].params + + # set the number of blocks + if params.depth != num_double_blocks: + logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}") + params = replace(params, depth=num_double_blocks) + if params.depth_single_blocks != num_single_blocks: + logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}") + params = replace(params, depth_single_blocks=num_single_blocks) + + model = flux_models.Flux(params) + if dtype is not None: + model = model.to(dtype) + + # load_sft doesn't support torch.device + logger.info(f"Loading state dict from {ckpt_path}") + sd = {} + for ckpt_path in ckpt_paths: + sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)) + + # convert Diffusers to BFL + if is_diffusers: + logger.info("Converting Diffusers to BFL") + sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks) + logger.info("Converted Diffusers to BFL") + + # if the key has annoying prefix, remove it + for key in list(sd.keys()): + new_key = key.replace("model.diffusion_model.", "") + if new_key == key: + break # the model doesn't have annoying prefix + sd[new_key] = sd.pop(key) + + info = model.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded Flux: {info}") + return is_schnell, model + + +def load_ae( + ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False +) -> flux_models.AutoEncoder: + logger.info("Building AutoEncoder") + with torch.device("meta"): + # dev and schnell have the same AE params + ae = flux_models.AutoEncoder(flux_models.configs[MODEL_NAME_DEV].ae_params).to(dtype) + + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = ae.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded AE: {info}") + return ae + + +def load_controlnet( + ckpt_path: Optional[str], is_schnell: bool, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False +): + logger.info("Building ControlNet") + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + with torch.device(device): + controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype) + + if ckpt_path is not None: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = controlnet.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded ControlNet: {info}") + return controlnet + + +def load_clip_l( + ckpt_path: Optional[str], + dtype: torch.dtype, + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> CLIPTextModel: + logger.info("Building CLIP-L") + CLIPL_CONFIG = { + "_name_or_path": "clip-vit-large-patch14/", + "architectures": ["CLIPModel"], + "initializer_factor": 1.0, + "logit_scale_init_value": 2.6592, + "model_type": "clip", + "projection_dim": 768, + # "text_config": { + "_name_or_path": "", + "add_cross_attention": False, + "architectures": None, + "attention_dropout": 0.0, + "bad_words_ids": None, + "bos_token_id": 0, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": None, + "decoder_start_token_id": None, + "diversity_penalty": 0.0, + "do_sample": False, + "dropout": 0.0, + "early_stopping": False, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": 2, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "hidden_act": "quick_gelu", + "hidden_size": 768, + "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 3072, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "layer_norm_eps": 1e-05, + "length_penalty": 1.0, + "max_length": 20, + "max_position_embeddings": 77, + "min_length": 0, + "model_type": "clip_text_model", + "no_repeat_ngram_size": 0, + "num_attention_heads": 12, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 12, + "num_return_sequences": 1, + "output_attentions": False, + "output_hidden_states": False, + "output_scores": False, + "pad_token_id": 1, + "prefix": None, + "problem_type": None, + "projection_dim": 768, + "pruned_heads": {}, + "remove_invalid_values": False, + "repetition_penalty": 1.0, + "return_dict": True, + "return_dict_in_generate": False, + "sep_token_id": None, + "task_specific_params": None, + "temperature": 1.0, + "tie_encoder_decoder": False, + "tie_word_embeddings": True, + "tokenizer_class": None, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": None, + "torchscript": False, + "transformers_version": "4.16.0.dev0", + "use_bfloat16": False, + "vocab_size": 49408, + "hidden_act": "gelu", + "hidden_size": 1280, + "intermediate_size": 5120, + "num_attention_heads": 20, + "num_hidden_layers": 32, + # }, + # "text_config_dict": { + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "projection_dim": 768, + # }, + # "torch_dtype": "float32", + # "transformers_version": None, + } + config = CLIPConfig(**CLIPL_CONFIG) + with init_empty_weights(): + clip = CLIPTextModel._from_config(config) + + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = clip.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded CLIP-L: {info}") + return clip + + +def load_t5xxl( + ckpt_path: str, + dtype: Optional[torch.dtype], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> T5EncoderModel: + T5_CONFIG_JSON = """ +{ + "architectures": [ + "T5EncoderModel" + ], + "classifier_dropout": 0.0, + "d_ff": 10240, + "d_kv": 64, + "d_model": 4096, + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 24, + "num_heads": 64, + "num_layers": 24, + "output_past": true, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.41.2", + "use_cache": true, + "vocab_size": 32128 +} +""" + config = json.loads(T5_CONFIG_JSON) + config = T5Config(**config) + with init_empty_weights(): + t5xxl = T5EncoderModel._from_config(config) + + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = t5xxl.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded T5xxl: {info}") + return t5xxl + + +def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype: + # nn.Embedding is the first layer, but it could be casted to bfloat16 or float32 + return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype + + +def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int): + img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :] + img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + return img_ids + + +def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: + """ + x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) + return x + + +def pack_latents(x: torch.Tensor) -> torch.Tensor: + """ + x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + return x + + +# region Diffusers + +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + +BFL_TO_DIFFUSERS_MAP = { + "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], + "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], + "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], + "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], + "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], + "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], + "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], + "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], + "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], + "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], + "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], + "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], + "txt_in.weight": ["context_embedder.weight"], + "txt_in.bias": ["context_embedder.bias"], + "img_in.weight": ["x_embedder.weight"], + "img_in.bias": ["x_embedder.bias"], + "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], + "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], + "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], + "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], + "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], + "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], + "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], + "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], + "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], + "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], + "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], + "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], + "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], + "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], + "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], + "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], + "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], + "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], + "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], + "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], + "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], + "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], + "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], + "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], + "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], + "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], + "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], + "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], + "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().linear2.bias": ["proj_out.bias"], + "final_layer.linear.weight": ["proj_out.weight"], + "final_layer.linear.bias": ["proj_out.bias"], + "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], + "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], +} + + +def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]: + # make reverse map from diffusers map + diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) + for b in range(num_double_blocks): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("double_blocks."): + block_prefix = f"transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for b in range(num_single_blocks): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("single_blocks."): + block_prefix = f"single_transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")): + for i, weight in enumerate(weights): + diffusers_to_bfl_map[weight] = (i, key) + return diffusers_to_bfl_map + + +def convert_diffusers_sd_to_bfl( + diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS +) -> dict[str, torch.Tensor]: + diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks) + + # iterate over three safetensors files to reduce memory usage + flux_sd = {} + for diffusers_key, tensor in diffusers_sd.items(): + if diffusers_key in diffusers_to_bfl_map: + index, bfl_key = diffusers_to_bfl_map[diffusers_key] + if bfl_key not in flux_sd: + flux_sd[bfl_key] = [] + flux_sd[bfl_key].append((index, tensor)) + else: + logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}") + raise KeyError(f"Key not found in diffusers_to_bfl_map: {diffusers_key}") + + # concat tensors if multiple tensors are mapped to a single key, sort by index + for key, values in flux_sd.items(): + if len(values) == 1: + flux_sd[key] = values[0][1] + else: + flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])]) + + # special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + if "final_layer.adaLN_modulation.1.weight" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"]) + if "final_layer.adaLN_modulation.1.bias" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"]) + + return flux_sd + + +# endregion diff --git a/library/ipex/attention.py b/library/ipex/attention.py index d989ad53..2bc62f65 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -5,7 +5,7 @@ from functools import cache # pylint: disable=protected-access, missing-function-docstring, line-too-long -# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers +# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4)) attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index a63bd82e..8896c047 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -6,8 +6,10 @@ import os from typing import List, Optional, Tuple, Union import safetensors from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) r""" @@ -55,12 +57,18 @@ ARCH_SD_V1 = "stable-diffusion-v1" ARCH_SD_V2_512 = "stable-diffusion-v2-512" ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v" ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" +ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc. +# ARCH_SD3_UNKNOWN = "stable-diffusion-3" +ARCH_FLUX_1_DEV = "flux-1-dev" +ARCH_FLUX_1_UNKNOWN = "flux-1" ADAPTER_LORA = "lora" ADAPTER_TEXTUAL_INVERSION = "textual-inversion" IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" +IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI" IMPL_DIFFUSERS = "diffusers" +IMPL_FLUX = "https://github.com/black-forest-labs/flux" PRED_TYPE_EPSILON = "epsilon" PRED_TYPE_V = "v" @@ -113,7 +121,12 @@ def build_metadata( merged_from: Optional[str] = None, timesteps: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, + sd3: Optional[str] = None, + flux: Optional[str] = None, ): + """ + sd3: only supports "m", flux: only supports "dev" + """ # if state_dict is None, hash is not calculated metadata = {} @@ -126,6 +139,13 @@ def build_metadata( if sdxl: arch = ARCH_SD_XL_V1_BASE + elif sd3 is not None: + arch = ARCH_SD3_M + "-" + sd3 + elif flux is not None: + if flux == "dev": + arch = ARCH_FLUX_1_DEV + else: + arch = ARCH_FLUX_1_UNKNOWN elif v2: if v_parameterization: arch = ARCH_SD_V2_768_V @@ -142,9 +162,12 @@ def build_metadata( metadata["modelspec.architecture"] = arch if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: - is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion + is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion - if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: + if flux is not None: + # Flux + impl = IMPL_FLUX + elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: # Stable Diffusion ckpt, TI, SDXL LoRA impl = IMPL_STABILITY_AI else: @@ -202,7 +225,7 @@ def build_metadata( reso = (reso[0], reso[0]) else: # resolution is defined in dataset, so use default - if sdxl: + if sdxl or sd3 is not None or flux is not None: reso = 1024 elif v2 and v_parameterization: reso = 768 @@ -213,7 +236,9 @@ def build_metadata( metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}" - if v_parameterization: + if flux is not None: + del metadata["modelspec.prediction_type"] + elif v_parameterization: metadata["modelspec.prediction_type"] = PRED_TYPE_V else: metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON @@ -236,7 +261,7 @@ def build_metadata( # assert all([v is not None for v in metadata.values()]), metadata if not all([v is not None for v in metadata.values()]): logger.error(f"Internal error: some metadata values are None: {metadata}") - + return metadata @@ -250,7 +275,7 @@ def get_title(metadata: dict) -> Optional[str]: def load_metadata_from_safetensors(model: str) -> dict: if not model.endswith(".safetensors"): return {} - + with safetensors.safe_open(model, framework="pt") as f: metadata = f.metadata() if metadata is None: diff --git a/library/sd3_models.py b/library/sd3_models.py new file mode 100644 index 00000000..e4a93186 --- /dev/null +++ b/library/sd3_models.py @@ -0,0 +1,1428 @@ +# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref +# the original code is licensed under the MIT License + +# and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution! + +from ast import Tuple +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from functools import partial +import math +from types import SimpleNamespace +from typing import Dict, List, Optional, Union +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint +from transformers import CLIPTokenizer, T5TokenizerFast + +from library import custom_offloading_utils +from library.device_utils import clean_memory_on_device + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +memory_efficient_attention = None +try: + import xformers +except: + pass + +try: + from xformers.ops import memory_efficient_attention +except: + memory_efficient_attention = None + + +# region mmdit + + +@dataclass +class SD3Params: + patch_size: int + depth: int + num_patches: int + pos_embed_max_size: int + adm_in_channels: int + qk_norm: Optional[str] + x_block_self_attn_layers: list[int] + context_embedder_in_features: int + context_embedder_out_features: int + model_type: str + + +def get_2d_sincos_pos_embed( + embed_dim, + grid_size, + scaling_factor=None, + offset=None, +): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + if scaling_factor is not None: + grid = grid / scaling_factor + if offset is not None: + grid = grid - offset + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_scaled_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, sample_size=64, base_size=16): + """ + This function is contributed by KohakuBlueleaf. Thanks for the contribution! + + Creates scaled 2D sinusoidal positional embeddings that maintain consistent relative positions + when the resolution differs from the training resolution. + + Args: + embed_dim (int): Dimension of the positional embedding. + grid_size (int or tuple): Size of the position grid (H, W). If int, assumes square grid. + cls_token (bool): Whether to include class token. Defaults to False. + extra_tokens (int): Number of extra tokens (e.g., cls_token). Defaults to 0. + sample_size (int): Reference resolution (typically training resolution). Defaults to 64. + base_size (int): Base grid size used during training. Defaults to 16. + + Returns: + numpy.ndarray: Positional embeddings of shape (H*W, embed_dim) or + (H*W + extra_tokens, embed_dim) if cls_token is True. + """ + # Convert grid_size to tuple if it's an integer + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + # Create normalized grid coordinates (0 to 1) + grid_h = np.arange(grid_size[0], dtype=np.float32) / grid_size[0] + grid_w = np.arange(grid_size[1], dtype=np.float32) / grid_size[1] + + # Calculate scaling factors for height and width + # This ensures that the central region matches the original resolution's embeddings + scale_h = base_size * grid_size[0] / (sample_size) + scale_w = base_size * grid_size[1] / (sample_size) + + # Calculate shift values to center the original resolution's embedding region + # This ensures that the central sample_size x sample_size region has similar + # positional embeddings to the original resolution + shift_h = 1 * scale_h * (grid_size[0] - sample_size) / (2 * grid_size[0]) + shift_w = 1 * scale_w * (grid_size[1] - sample_size) / (2 * grid_size[1]) + + # Apply scaling and shifting to create the final grid coordinates + grid_h = grid_h * scale_h - shift_h + grid_w = grid_w * scale_w - shift_w + + # Create 2D grid using meshgrid (note: w goes first) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0) + + # # Calculate the starting indices for the central region + # # This is used for debugging/visualization of the central region + # st_h = (grid_size[0] - sample_size) // 2 + # st_w = (grid_size[1] - sample_size) // 2 + # print(grid[:, st_h : st_h + sample_size, st_w : st_w + sample_size]) + + # Reshape grid for positional embedding calculation + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + + # Generate the sinusoidal positional embeddings + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + + # Add zeros for extra tokens (e.g., [CLS] token) if required + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + + return pos_embed + + +# if __name__ == "__main__": +# # This is what you get when you load SD3.5 state dict +# pos_emb = torch.from_numpy(get_scaled_2d_sincos_pos_embed( +# 1536, [384, 384], sample_size=64, base_size=16 +# )).float().unsqueeze(0) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid_torch( + embed_dim, + pos, + device=None, + dtype=torch.float32, +): + omega = torch.arange(embed_dim // 2, device=device, dtype=dtype) + omega *= 2.0 / embed_dim + omega = 1.0 / 10000**omega + out = torch.outer(pos.reshape(-1), omega) + emb = torch.cat([out.sin(), out.cos()], dim=1) + return emb + + +def get_2d_sincos_pos_embed_torch( + embed_dim, + w, + h, + val_center=7.5, + val_magnitude=7.5, + device=None, + dtype=torch.float32, +): + small = min(h, w) + val_h = (h / small) * val_magnitude + val_w = (w / small) * val_magnitude + grid_h, grid_w = torch.meshgrid( + torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype), + torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype), + indexing="ij", + ) + emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype) + emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype) + emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D) + return emb + + +def modulate(x, shift, scale): + if shift is None: + shift = torch.zeros_like(scale) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def default(x, default_value): + if x is None: + return default_value + return x + + +def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + # device=t.device, dtype=t.dtype + # ) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(dtype=t.dtype) + return embedding + + +class PatchEmbed(nn.Module): + def __init__( + self, + img_size=256, + patch_size=4, + in_channels=3, + embed_dim=512, + norm_layer=None, + flatten=True, + bias=True, + strict_img_size=True, + dynamic_img_pad=False, + ): + # dynamic_img_pad and norm is omitted in SD3.5 + super().__init__() + self.patch_size = patch_size + self.flatten = flatten + self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad + if img_size is not None: + self.img_size = img_size + self.grid_size = img_size // patch_size + self.num_patches = self.grid_size**2 + else: + self.img_size = None + self.grid_size = None + self.num_patches = None + + self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias) + self.norm = nn.Identity() if norm_layer is None else norm_layer(embed_dim) + + def forward(self, x): + B, C, H, W = x.shape + + if self.dynamic_img_pad: + # Pad input so we won't have partial patch + pad_h = (self.patch_size - H % self.patch_size) % self.patch_size + pad_w = (self.patch_size - W % self.patch_size) % self.patch_size + x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="reflect") + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +# FinalLayer in mmdit.py +class UnPatch(nn.Module): + def __init__(self, hidden_size=512, patch_size=4, out_channels=3): + super().__init__() + self.patch_size = patch_size + self.c = out_channels + + # eps is default in mmdit.py + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size), + ) + + def forward(self, x: torch.Tensor, cmod, H=None, W=None): + b, n, _ = x.shape + p = self.patch_size + c = self.c + if H is None and W is None: + w = h = int(n**0.5) + assert h * w == n + else: + h = H // p if H else n // (W // p) + w = W // p if W else n // h + assert h * w == n + + shift, scale = self.adaLN_modulation(cmod).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + + x = x.view(b, h, w, p, p, c) + x = x.permute(0, 5, 1, 3, 2, 4).contiguous() + x = x.view(b, c, h * p, w * p) + return x + + +class MLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=lambda: nn.GELU(), + norm_layer=None, + bias=True, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.use_conv = use_conv + + layer = partial(nn.Conv1d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = layer(in_features, hidden_features, bias=bias) + self.fc2 = layer(hidden_features, out_features, bias=bias) + self.act = act_layer() + self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.norm(x) + x = self.fc2(x) + return x + + +class TimestepEmbedding(nn.Module): + def __init__(self, hidden_size, freq_embed_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(freq_embed_size, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + self.freq_embed_size = freq_embed_size + + def forward(self, t, dtype=None, **kwargs): + t_freq = timestep_embedding(t, self.freq_embed_size).to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class Embedder(nn.Module): + def __init__(self, input_dim, hidden_size): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(input_dim, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + + def forward(self, x): + return self.mlp(x) + + +def rmsnorm(x, eps=1e-6): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +class RMSNorm(torch.nn.Module): + def __init__( + self, + dim: int, + elementwise_affine: bool = False, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + """ + super().__init__() + self.eps = eps + self.learnable_scale = elementwise_affine + if self.learnable_scale: + self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + else: + self.register_parameter("weight", None) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + """ + x = rmsnorm(x, eps=self.eps) + if self.learnable_scale: + return x * self.weight.to(device=x.device, dtype=x.dtype) + else: + return x + + +class SwiGLUFeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: float = None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +# Linears for SelfAttention in mmdit.py +class AttentionLinears(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + pre_only: bool = False, + qk_norm: Optional[str] = None, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + if not pre_only: + self.proj = nn.Linear(dim, dim) + self.pre_only = pre_only + + if qk_norm == "rms": + self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + elif qk_norm == "ln": + self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + elif qk_norm is None: + self.ln_q = nn.Identity() + self.ln_k = nn.Identity() + else: + raise ValueError(qk_norm) + + def pre_attention(self, x: torch.Tensor) -> torch.Tensor: + """ + output: + q, k, v: [B, L, D] + """ + B, L, C = x.shape + qkv: torch.Tensor = self.qkv(x) + q, k, v = qkv.reshape(B, L, -1, self.head_dim).chunk(3, dim=2) + q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1) + k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1) + return (q, k, v) + + def post_attention(self, x: torch.Tensor) -> torch.Tensor: + assert not self.pre_only + x = self.proj(x) + return x + + +MEMORY_LAYOUTS = { + "torch": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2), + lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1), + lambda x: (1, x, 1, 1), + ), + "xformers": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim), + lambda x: x.reshape(x.shape[0], x.shape[1], -1), + lambda x: (1, 1, x, 1), + ), + "math": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2), + lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1), + lambda x: (1, x, 1, 1), + ), +} +# ATTN_FUNCTION = { +# "torch": F.scaled_dot_product_attention, +# "xformers": memory_efficient_attention, +# } + + +def vanilla_attention(q, k, v, mask, scale=None): + if scale is None: + scale = math.sqrt(q.size(-1)) + scores = torch.bmm(q, k.transpose(-1, -2)) / scale + if mask is not None: + mask = einops.rearrange(mask, "b ... -> b (...)") + max_neg_value = -torch.finfo(scores.dtype).max + mask = einops.repeat(mask, "b j -> (b h) j", h=q.size(-3)) + scores = scores.masked_fill(~mask, max_neg_value) + p_attn = F.softmax(scores, dim=-1) + return torch.bmm(p_attn, v) + + +def attention(q, k, v, head_dim, mask=None, scale=None, mode="xformers"): + """ + q, k, v: [B, L, D] + """ + pre_attn_layout = MEMORY_LAYOUTS[mode][0] + post_attn_layout = MEMORY_LAYOUTS[mode][1] + q = pre_attn_layout(q, head_dim) + k = pre_attn_layout(k, head_dim) + v = pre_attn_layout(v, head_dim) + + # scores = ATTN_FUNCTION[mode](q, k.to(q), v.to(q), mask, scale=scale) + if mode == "torch": + assert scale is None + scores = F.scaled_dot_product_attention(q, k.to(q), v.to(q), mask) # , scale=scale) + elif mode == "xformers": + scores = memory_efficient_attention(q, k.to(q), v.to(q), mask, scale=scale) + else: + scores = vanilla_attention(q, k.to(q), v.to(q), mask, scale=scale) + + scores = post_attn_layout(scores) + return scores + + +# DismantledBlock in mmdit.py +class SingleDiTBlock(nn.Module): + """ + A DiT block with gated adaptive layer norm (adaLN) conditioning. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: str = "xformers", + qkv_bias: bool = False, + pre_only: bool = False, + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + qk_norm: Optional[str] = None, + x_block_self_attn: bool = False, + **block_kwargs, + ): + super().__init__() + assert attn_mode in MEMORY_LAYOUTS + self.attn_mode = attn_mode + if not rmsnorm: + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + else: + self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=pre_only, qk_norm=qk_norm) + + self.x_block_self_attn = x_block_self_attn + if self.x_block_self_attn: + assert not pre_only + assert not scale_mod_only + self.attn2 = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=False, qk_norm=qk_norm) + + if not pre_only: + if not rmsnorm: + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + else: + self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + if not pre_only: + if not swiglu: + self.mlp = MLP( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=lambda: nn.GELU(approximate="tanh"), + ) + else: + self.mlp = SwiGLUFeedForward( + dim=hidden_size, + hidden_dim=mlp_hidden_dim, + multiple_of=256, + ) + self.scale_mod_only = scale_mod_only + if self.x_block_self_attn: + n_mods = 9 + elif not scale_mod_only: + n_mods = 6 if not pre_only else 2 + else: + n_mods = 4 if not pre_only else 1 + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size)) + self.pre_only = pre_only + + def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + if not self.pre_only: + if not self.scale_mod_only: + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(6, dim=-1) + else: + shift_msa = None + shift_mlp = None + (scale_msa, gate_msa, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(4, dim=-1) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp) + else: + if not self.scale_mod_only: + (shift_msa, scale_msa) = self.adaLN_modulation(c).chunk(2, dim=-1) + else: + shift_msa = None + scale_msa = self.adaLN_modulation(c) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, None + + def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + assert self.x_block_self_attn + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2) = self.adaLN_modulation( + c + ).chunk(9, dim=1) + x_norm = self.norm1(x) + qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa)) + qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2)) + return qkv, qkv2, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2) + + def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp): + assert not self.pre_only + x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2, attn1_dropout: float = 0.0): + assert not self.pre_only + if attn1_dropout > 0.0: + # Use torch.bernoulli to implement dropout, only dropout the batch dimension + attn1_dropout = torch.bernoulli(torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device)) + attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout + else: + attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + attn_ + attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2) + x = x + attn2_ + mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + x = x + mlp_ + return x + + +# JointBlock + block_mixing in mmdit.py +class MMDiTBlock(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + pre_only = kwargs.pop("pre_only") + x_block_self_attn = kwargs.pop("x_block_self_attn") + + self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs) + self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs) + + self.head_dim = self.x_block.attn.head_dim + self.mode = self.x_block.attn_mode + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def _forward(self, context, x, c): + ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c) + + if self.x_block.x_block_self_attn: + x_qkv, x_qkv2, x_intermediates = self.x_block.pre_attention_x(x, c) + else: + x_qkv, x_intermediates = self.x_block.pre_attention(x, c) + + ctx_len = ctx_qkv[0].size(1) + + q = torch.concat((ctx_qkv[0], x_qkv[0]), dim=1) + k = torch.concat((ctx_qkv[1], x_qkv[1]), dim=1) + v = torch.concat((ctx_qkv[2], x_qkv[2]), dim=1) + + attn = attention(q, k, v, head_dim=self.head_dim, mode=self.mode) + ctx_attn_out = attn[:, :ctx_len] + x_attn_out = attn[:, ctx_len:] + + if self.x_block.x_block_self_attn: + x_q2, x_k2, x_v2 = x_qkv2 + attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads, mode=self.mode) + x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates) + else: + x = self.x_block.post_attention(x_attn_out, *x_intermediates) + + if not self.context_block.pre_only: + context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate) + else: + context = None + + return context, x + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + +class MMDiT(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + + # prepare pos_embed for latent size * 2 + POS_EMBED_MAX_RATIO = 1.5 + + def __init__( + self, + input_size: int = 32, + patch_size: int = 2, + in_channels: int = 4, + depth: int = 28, + # hidden_size: Optional[int] = None, + # num_heads: Optional[int] = None, + mlp_ratio: float = 4.0, + learn_sigma: bool = False, + adm_in_channels: Optional[int] = None, + context_embedder_in_features: Optional[int] = None, + context_embedder_out_features: Optional[int] = None, + use_checkpoint: bool = False, + register_length: int = 0, + attn_mode: str = "torch", + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + out_channels: Optional[int] = None, + pos_embed_scaling_factor: Optional[float] = None, + pos_embed_offset: Optional[float] = None, + pos_embed_max_size: Optional[int] = None, + num_patches=None, + qk_norm: Optional[str] = None, + x_block_self_attn_layers: Optional[list[int]] = [], + qkv_bias: bool = True, + pos_emb_random_crop_rate: float = 0.0, + use_scaled_pos_embed: bool = False, + pos_embed_latent_sizes: Optional[list[int]] = None, + model_type: str = "sd3m", + ): + super().__init__() + self._model_type = model_type + self.learn_sigma = learn_sigma + self.in_channels = in_channels + default_out_channels = in_channels * 2 if learn_sigma else in_channels + self.out_channels = default(out_channels, default_out_channels) + self.patch_size = patch_size + self.pos_embed_scaling_factor = pos_embed_scaling_factor + self.pos_embed_offset = pos_embed_offset + self.pos_embed_max_size = pos_embed_max_size + self.x_block_self_attn_layers = x_block_self_attn_layers + self.pos_emb_random_crop_rate = pos_emb_random_crop_rate + self.gradient_checkpointing = use_checkpoint + + # hidden_size = default(hidden_size, 64 * depth) + # num_heads = default(num_heads, hidden_size // 64) + + # apply magic --> this defines a head_size of 64 + self.hidden_size = 64 * depth + num_heads = depth + + self.num_heads = num_heads + + self.enable_scaled_pos_embed(use_scaled_pos_embed, pos_embed_latent_sizes) + + self.x_embedder = PatchEmbed( + input_size, + patch_size, + in_channels, + self.hidden_size, + bias=True, + strict_img_size=self.pos_embed_max_size is None, + ) + self.t_embedder = TimestepEmbedding(self.hidden_size) + + self.y_embedder = None + if adm_in_channels is not None: + assert isinstance(adm_in_channels, int) + self.y_embedder = Embedder(adm_in_channels, self.hidden_size) + + if context_embedder_in_features is not None: + self.context_embedder = nn.Linear(context_embedder_in_features, context_embedder_out_features) + else: + self.context_embedder = nn.Identity() + + self.register_length = register_length + if self.register_length > 0: + self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size)) + + # num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + # just use a buffer already + if num_patches is not None: + self.register_buffer( + "pos_embed", + torch.empty(1, num_patches, self.hidden_size), + ) + else: + self.pos_embed = None + + self.use_checkpoint = use_checkpoint + self.joint_blocks = nn.ModuleList( + [ + MMDiTBlock( + self.hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + qkv_bias=qkv_bias, + pre_only=i == depth - 1, + rmsnorm=rmsnorm, + scale_mod_only=scale_mod_only, + swiglu=swiglu, + qk_norm=qk_norm, + x_block_self_attn=(i in self.x_block_self_attn_layers), + ) + for i in range(depth) + ] + ) + for block in self.joint_blocks: + block.gradient_checkpointing = use_checkpoint + + self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels) + # self.initialize_weights() + + self.blocks_to_swap = None + self.offloader = None + self.num_blocks = len(self.joint_blocks) + + def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]): + self.use_scaled_pos_embed = use_scaled_pos_embed + + if self.use_scaled_pos_embed: + # # remove pos_embed to free up memory up to 0.4 GB -> this causes error because pos_embed is not saved + # self.pos_embed = None + # move pos_embed to CPU to free up memory up to 0.4 GB + self.pos_embed = self.pos_embed.cpu() + + # remove duplicates and sort latent sizes in ascending order + latent_sizes = list(set(latent_sizes)) + latent_sizes = sorted(latent_sizes) + + patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes] + + # calculate value range for each latent area: this is used to determine the pos_emb size from the latent shape + max_areas = [] + for i in range(1, len(patched_sizes)): + prev_area = patched_sizes[i - 1] ** 2 + area = patched_sizes[i] ** 2 + max_areas.append((prev_area + area) // 2) + + # area of the last latent size, if the latent size exceeds this, error will be raised + max_areas.append(int((patched_sizes[-1] * MMDiT.POS_EMBED_MAX_RATIO) ** 2)) + # print("max_areas", max_areas) + + self.resolution_area_to_latent_size = [(area, latent_size) for area, latent_size in zip(max_areas, patched_sizes)] + + self.resolution_pos_embeds = {} + for patched_size in patched_sizes: + grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO) + pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) + self.resolution_pos_embeds[patched_size] = pos_embed + # print(f"pos_embed for {patched_size}x{patched_size} latent size: {pos_embed.shape}") + + else: + self.resolution_area_to_latent_size = None + self.resolution_pos_embeds = None + + @property + def model_type(self): + return self._model_type + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + for block in self.joint_blocks: + block.enable_gradient_checkpointing() + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + for block in self.joint_blocks: + block.disable_gradient_checkpointing() + + def initialize_weights(self): + # TODO: Init context_embedder? + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize (and freeze) pos_embed by sin-cos embedding + if self.pos_embed is not None: + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], + int(self.pos_embed.shape[-2] ** 0.5), + scaling_factor=self.pos_embed_scaling_factor, + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + if getattr(self, "y_embedder", None) is not None: + nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.joint_blocks: + nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def set_pos_emb_random_crop_rate(self, rate: float): + self.pos_emb_random_crop_rate = rate + + def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False): + p = self.x_embedder.patch_size + # patched size + h = (h + 1) // p + w = (w + 1) // p + if self.pos_embed is None: # should not happen + return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device) + assert self.pos_embed_max_size is not None + assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size) + assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size) + + if not random_crop: + top = (self.pos_embed_max_size - h) // 2 + left = (self.pos_embed_max_size - w) // 2 + else: + top = torch.randint(0, self.pos_embed_max_size - h + 1, (1,)).item() + left = torch.randint(0, self.pos_embed_max_size - w + 1, (1,)).item() + + spatial_pos_embed = self.pos_embed.reshape( + 1, + self.pos_embed_max_size, + self.pos_embed_max_size, + self.pos_embed.shape[-1], + ) + spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + return spatial_pos_embed + + def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: bool = False): + p = self.x_embedder.patch_size + # patched size + h = (h + 1) // p + w = (w + 1) // p + + # select pos_embed size based on area + area = h * w + patched_size = None + for area_, patched_size_ in self.resolution_area_to_latent_size: + if area <= area_: + patched_size = patched_size_ + break + if patched_size is None: + # raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.") + # use largest latent size + patched_size = self.resolution_area_to_latent_size[-1][1] + + pos_embed = self.resolution_pos_embeds[patched_size] + pos_embed_size = round(math.sqrt(pos_embed.shape[1])) # max size, patched_size * POS_EMBED_MAX_RATIO + if h > pos_embed_size or w > pos_embed_size: + # # fallback to normal pos_embed + # return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop) + # extend pos_embed size + logger.warning( + f"Add new pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide." + ) + patched_size = max(h, w) + grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO) + pos_embed_size = grid_size + pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) + self.resolution_pos_embeds[patched_size] = pos_embed + logger.info(f"Added pos_embed for size {patched_size}x{patched_size}") + + # print(torch.allclose(pos_embed.to(torch.float32).cpu(), self.pos_embed.to(torch.float32).cpu(), atol=5e-2)) + # diff = pos_embed.to(torch.float32).cpu() - self.pos_embed.to(torch.float32).cpu() + # print(diff.abs().max(), diff.abs().mean()) + + # insert to resolution_area_to_latent_size, by adding and sorting + area = pos_embed_size**2 + self.resolution_area_to_latent_size.append((area, patched_size)) + self.resolution_area_to_latent_size = sorted(self.resolution_area_to_latent_size) + + if not random_crop: + top = (pos_embed_size - h) // 2 + left = (pos_embed_size - w) // 2 + else: + top = torch.randint(0, pos_embed_size - h + 1, (1,)).item() + left = torch.randint(0, pos_embed_size - w + 1, (1,)).item() + + if pos_embed.device != device: + pos_embed = pos_embed.to(device) + # which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device. + self.resolution_pos_embeds[patched_size] = pos_embed # update device + if pos_embed.dtype != dtype: + pos_embed = pos_embed.to(dtype) + self.resolution_pos_embeds[patched_size] = pos_embed # update dtype + + spatial_pos_embed = pos_embed.reshape(1, pos_embed_size, pos_embed_size, pos_embed.shape[-1]) + spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + # print( + # f"patched size: {h}x{w}, pos_embed size: {pos_embed_size}, pos_embed shape: {pos_embed.shape}, top: {top}, left: {left}" + # ) + return spatial_pos_embed + + def enable_block_swap(self, num_blocks: int, device: torch.device): + self.blocks_to_swap = num_blocks + + assert ( + self.blocks_to_swap <= self.num_blocks - 2 + ), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks." + + self.offloader = custom_offloading_utils.ModelOffloader( + self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True + ) + print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.") + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_blocks = self.joint_blocks + self.joint_blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.joint_blocks = save_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader.prepare_block_devices_before_forward(self.joint_blocks) + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass of DiT. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, D) tensor of class labels + """ + pos_emb_random_crop = ( + False if self.pos_emb_random_crop_rate == 0.0 else torch.rand(1).item() < self.pos_emb_random_crop_rate + ) + + B, C, H, W = x.shape + + # x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype) + if not self.use_scaled_pos_embed: + pos_embed = self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype) + else: + # print(f"Using scaled pos_embed for size {H}x{W}") + pos_embed = self.cropped_scaled_pos_embed(H, W, device=x.device, dtype=x.dtype, random_crop=pos_emb_random_crop) + x = self.x_embedder(x) + pos_embed + del pos_embed + + c = self.t_embedder(t, dtype=x.dtype) # (N, D) + if y is not None and self.y_embedder is not None: + y = self.y_embedder(y) # (N, D) + c = c + y # (N, D) + + if context is not None: + context = self.context_embedder(context) + + if self.register_length > 0: + context = torch.cat( + (einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), default(context, torch.Tensor([]).type_as(x))), 1 + ) + + if not self.blocks_to_swap: + for block in self.joint_blocks: + context, x = block(context, x, c) + else: + for block_idx, block in enumerate(self.joint_blocks): + self.offloader.wait_for_block(block_idx) + + context, x = block(context, x, c) + + self.offloader.submit_move_blocks(self.joint_blocks, block_idx) + + x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify + return x[:, :, :H, :W] + + +def create_sd3_mmdit(params: SD3Params, attn_mode: str = "torch") -> MMDiT: + mmdit = MMDiT( + input_size=None, + pos_embed_max_size=params.pos_embed_max_size, + patch_size=params.patch_size, + in_channels=16, + adm_in_channels=params.adm_in_channels, + context_embedder_in_features=params.context_embedder_in_features, + context_embedder_out_features=params.context_embedder_out_features, + depth=params.depth, + mlp_ratio=4, + qk_norm=params.qk_norm, + x_block_self_attn_layers=params.x_block_self_attn_layers, + num_patches=params.num_patches, + attn_mode=attn_mode, + model_type=params.model_type, + ) + return mmdit + + +# endregion + +# region VAE + +VAE_SCALE_FACTOR = 1.5305 +VAE_SHIFT_FACTOR = 0.0609 + + +def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) + + +class ResnetBlock(torch.nn.Module): + def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = Normalize(in_channels, dtype=dtype, device=device) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.norm2 = Normalize(out_channels, dtype=dtype, device=device) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + if self.in_channels != self.out_channels: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device + ) + else: + self.nin_shortcut = None + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + hidden = x + hidden = self.norm1(hidden) + hidden = self.swish(hidden) + hidden = self.conv1(hidden) + hidden = self.norm2(hidden) + hidden = self.swish(hidden) + hidden = self.conv2(hidden) + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + return x + hidden + + +class AttnBlock(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.norm = Normalize(in_channels, dtype=dtype, device=device) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + + def forward(self, x): + hidden = self.norm(x) + q = self.q(hidden) + k = self.k(hidden) + v = self.v(hidden) + b, c, h, w = q.shape + q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)) + hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default + hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + hidden = self.proj_out(hidden) + return x + hidden + + +class Downsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device) + + def forward(self, x): + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + + def forward(self, x): + org_dtype = x.dtype + if x.dtype == torch.bfloat16: + x = x.to(torch.float32) + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if x.dtype != org_dtype: + x = x.to(org_dtype) + x = self.conv(x) + return x + + +class VAEEncoder(torch.nn.Module): + def __init__( + self, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = torch.nn.ModuleList() + for i_level in range(self.num_resolutions): + block = torch.nn.ModuleList() + attn = torch.nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device)) + block_in = block_out + down = torch.nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, dtype=dtype, device=device) + self.down.append(down) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = self.swish(h) + h = self.conv_out(h) + return h + + +class VAEDecoder(torch.nn.Module): + def __init__( + self, + ch=128, + out_ch=3, + ch_mult=(1, 2, 4, 4), + num_res_blocks=2, + resolution=256, + z_channels=16, + dtype=torch.float32, + device=None, + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + # upsampling + self.up = torch.nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = torch.nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device)) + block_in = block_out + up = torch.nn.Module() + up.block = block + if i_level != 0: + up.upsample = Upsample(block_in, dtype=dtype, device=device) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, z): + # z to block_in + hidden = self.conv_in(z) + # middle + hidden = self.mid.block_1(hidden) + hidden = self.mid.attn_1(hidden) + hidden = self.mid.block_2(hidden) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + hidden = self.up[i_level].block[i_block](hidden) + if i_level != 0: + hidden = self.up[i_level].upsample(hidden) + # end + hidden = self.norm_out(hidden) + hidden = self.swish(hidden) + hidden = self.conv_out(hidden) + return hidden + + +class SDVAE(torch.nn.Module): + def __init__(self, dtype=torch.float32, device=None): + super().__init__() + self.encoder = VAEEncoder(dtype=dtype, device=device) + self.decoder = VAEDecoder(dtype=dtype, device=device) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + # @torch.autocast("cuda", dtype=torch.float16) + def decode(self, latent): + return self.decoder(latent) + + # @torch.autocast("cuda", dtype=torch.float16) + def encode(self, image): + hidden = self.encoder(image) + mean, logvar = torch.chunk(hidden, 2, dim=1) + logvar = torch.clamp(logvar, -30.0, 20.0) + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + + @staticmethod + def process_in(latent): + return (latent - VAE_SHIFT_FACTOR) * VAE_SCALE_FACTOR + + @staticmethod + def process_out(latent): + return (latent / VAE_SCALE_FACTOR) + VAE_SHIFT_FACTOR + + +# endregion diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py new file mode 100644 index 00000000..c4079884 --- /dev/null +++ b/library/sd3_train_utils.py @@ -0,0 +1,945 @@ +import argparse +import math +import os +import toml +import json +import time +from typing import Dict, List, Optional, Tuple, Union + +import torch +from safetensors.torch import save_file +from accelerate import Accelerator, PartialState +from tqdm import tqdm +from PIL import Image +from transformers import CLIPTextModelWithProjection, T5EncoderModel + +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +# from transformers import CLIPTokenizer +# from library import model_util +# , sdxl_model_util, train_util, sdxl_original_unet +# from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import sd3_models, sd3_utils, strategy_base, train_util + + +def save_models( + ckpt_path: str, + mmdit: Optional[sd3_models.MMDiT], + vae: Optional[sd3_models.SDVAE], + clip_l: Optional[CLIPTextModelWithProjection], + clip_g: Optional[CLIPTextModelWithProjection], + t5xxl: Optional[T5EncoderModel], + sai_metadata: Optional[dict], + save_dtype: Optional[torch.dtype] = None, +): + r""" + Save models to checkpoint file. Only supports unified checkpoint format. + """ + + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + update_sd("model.diffusion_model.", mmdit.state_dict()) + update_sd("first_stage_model.", vae.state_dict()) + + # do not support unified checkpoint format for now + # if clip_l is not None: + # update_sd("text_encoders.clip_l.", clip_l.state_dict()) + # if clip_g is not None: + # update_sd("text_encoders.clip_g.", clip_g.state_dict()) + # if t5xxl is not None: + # update_sd("text_encoders.t5xxl.", t5xxl.state_dict()) + + save_file(state_dict, ckpt_path, metadata=sai_metadata) + + if clip_l is not None: + clip_l_path = ckpt_path.replace(".safetensors", "_clip_l.safetensors") + save_file(clip_l.state_dict(), clip_l_path) + if clip_g is not None: + clip_g_path = ckpt_path.replace(".safetensors", "_clip_g.safetensors") + save_file(clip_g.state_dict(), clip_g_path) + if t5xxl is not None: + t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors") + t5xxl_state_dict = t5xxl.state_dict() + + # replace "shared.weight" with copy of it to avoid annoying shared tensor error on safetensors.save_file + shared_weight = t5xxl_state_dict["shared.weight"] + shared_weight_copy = shared_weight.detach().clone() + t5xxl_state_dict["shared.weight"] = shared_weight_copy + + save_file(t5xxl_state_dict, t5xxl_path) + + +def save_sd3_model_on_train_end( + args: argparse.Namespace, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + clip_l: Optional[CLIPTextModelWithProjection], + clip_g: Optional[CLIPTextModelWithProjection], + t5xxl: Optional[T5EncoderModel], + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec( + None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type + ) + save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype) + + train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_sd3_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + clip_l: Optional[CLIPTextModelWithProjection], + clip_g: Optional[CLIPTextModelWithProjection], + t5xxl: Optional[T5EncoderModel], + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec( + None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type + ) + save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + True, + True, + epoch, + num_train_epochs, + global_step, + sd_saver, + None, + ) + + +def add_sd3_training_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--clip_l", + type=str, + required=False, + help="CLIP-L model path. if not specified, use ckpt's state_dict / CLIP-Lモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--clip_g", + type=str, + required=False, + help="CLIP-G model path. if not specified, use ckpt's state_dict / CLIP-Gモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--t5xxl", + type=str, + required=False, + help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--save_clip", + action="store_true", + help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません", + ) + parser.add_argument( + "--save_t5xxl", + action="store_true", + help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません", + ) + + parser.add_argument( + "--t5xxl_device", + type=str, + default=None, + help="[DOES NOT WORK] not supported yet. T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用", + ) + parser.add_argument( + "--t5xxl_dtype", + type=str, + default=None, + help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用", + ) + + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=256, + help="maximum token length for T5-XXL. 256 is the default value / T5-XXLの最大トークン長。デフォルトは256", + ) + parser.add_argument( + "--apply_lg_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する", + ) + parser.add_argument( + "--apply_t5_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", + ) + parser.add_argument( + "--clip_l_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for CLIP-L encoder, default is 0.0 / CLIP-Lエンコーダのドロップアウト率、デフォルトは0.0", + ) + parser.add_argument( + "--clip_g_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for CLIP-G encoder, default is 0.0 / CLIP-Gエンコーダのドロップアウト率、デフォルトは0.0", + ) + parser.add_argument( + "--t5_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0", + ) + parser.add_argument( + "--pos_emb_random_crop_rate", + type=float, + default=0.0, + help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M" + " / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります", + ) + parser.add_argument( + "--enable_scaled_pos_embed", + action="store_true", + help="Scale position embeddings for each resolution during multi-resolution training. Only for SD3.5M" + " / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります", + ) + + # Dependencies of Diffusers noise sampler has been removed for clarity in training + + parser.add_argument( + "--training_shift", + type=float, + default=1.0, + help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。", + ) + + +def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): + assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" + if args.v_parameterization: + logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") + + if args.clip_skip is not None: + logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") + + # if args.multires_noise_iterations: + # logger.info( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" + # ) + # else: + # if args.noise_offset is None: + # args.noise_offset = DEFAULT_NOISE_OFFSET + # elif args.noise_offset != DEFAULT_NOISE_OFFSET: + # logger.info( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" + # ) + # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") + + assert ( + not hasattr(args, "weighted_captions") or not args.weighted_captions + ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" + + if supportTextEncoderCaching: + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + args.cache_text_encoder_outputs = True + logger.warning( + "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / " + + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました" + ) + + +# temporary copied from sd3_minimal_inferece.py + + +def get_all_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): + start = sampling.timestep(sampling.sigma_max) + end = sampling.timestep(sampling.sigma_min) + timesteps = torch.linspace(start, end, steps) + sigs = [] + for x in range(len(timesteps)): + ts = timesteps[x] + sigs.append(sampling.sigma(ts)) + sigs += [0.0] + return torch.FloatTensor(sigs) + + +def max_denoise(model_sampling, sigmas): + max_sigma = float(model_sampling.sigma_max) + sigma = float(sigmas[0]) + return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma + + +def do_sample( + height: int, + width: int, + seed: int, + cond: Tuple[torch.Tensor, torch.Tensor], + neg_cond: Tuple[torch.Tensor, torch.Tensor], + mmdit: sd3_models.MMDiT, + steps: int, + guidance_scale: float, + dtype: torch.dtype, + device: str, +): + latent = torch.zeros(1, 16, height // 8, width // 8, device=device) + latent = latent.to(dtype).to(device) + + # noise = get_noise(seed, latent).to(device) + if seed is not None: + generator = torch.manual_seed(seed) + else: + generator = None + noise = ( + torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu") + .to(latent.dtype) + .to(device) + ) + + model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 + + sigmas = get_all_sigmas(model_sampling, steps).to(device) + + noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas)) + + c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype) + y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype) + + x = noise_scaled.to(device).to(dtype) + # print(x.shape) + + # with torch.no_grad(): + for i in tqdm(range(len(sigmas) - 1)): + sigma_hat = sigmas[i] + + timestep = model_sampling.timestep(sigma_hat).float() + timestep = torch.FloatTensor([timestep, timestep]).to(device) + + x_c_nc = torch.cat([x, x], dim=0) + # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + + mmdit.prepare_block_swap_before_forward() + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + model_output = model_output.float() + batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) + + pos_out, neg_out = batched.chunk(2) + denoised = neg_out + (pos_out - neg_out) * guidance_scale + # print(denoised.shape) + + # d = to_d(x, sigma_hat, denoised) + dims_to_append = x.ndim - sigma_hat.ndim + sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] + # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) + """Converts a denoiser output to a Karras ODE derivative.""" + d = (x - denoised) / sigma_hat_dims + + dt = sigmas[i + 1] - sigma_hat + + # Euler method + x = x + d * dt + x = x.to(dtype) + + mmdit.prepare_block_swap_before_forward() + return x + + +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + mmdit, + vae, + text_encoders, + sample_prompts_te_outputs, + prompt_replacement=None, +): + if steps == 0: + if not args.sample_at_first: + return + else: + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return + + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None: + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return + + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap unet and text_encoder(s) + mmdit = accelerator.unwrap_model(mmdit) + text_encoders = None if text_encoders is None else [accelerator.unwrap_model(te) for te in text_encoders] + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = train_util.load_prompts(args.sample_prompts) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(), accelerator.autocast(): + for prompt_dict in prompts: + sample_image_inference( + accelerator, + args, + mmdit, + text_encoders, + vae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, + args, + mmdit, + text_encoders, + vae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + ) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + clean_memory_on_device(accelerator.device) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + mmdit: sd3_models.MMDiT, + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], + vae: sd3_models.SDVAE, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, +): + assert isinstance(prompt_dict, dict) + negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 30) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 7.5) + seed = prompt_dict.get("seed") + # controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + if negative_prompt is None: + negative_prompt = "" + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + logger.info(f"prompt: {prompt}") + logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + # logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + def encode_prompt(prpt): + text_encoder_conds = [] + if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs: + text_encoder_conds = sample_prompts_te_outputs[prpt] + print(f"Using cached text encoder outputs for prompt: {prpt}") + if text_encoders is not None: + print(f"Encoding prompt: {prpt}") + tokens_and_masks = tokenize_strategy.tokenize(prpt) + # strategy has apply_t5_attn_mask option + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + return text_encoder_conds + + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(prompt) + cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # encode negative prompts + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(negative_prompt) + neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # sample image + clean_memory_on_device(accelerator.device) + with accelerator.autocast(), torch.no_grad(): + # mmdit may be fp8, so we need weight_dtype here. vae is always in that dtype. + latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, vae.dtype, accelerator.device) + + # latent to image + clean_memory_on_device(accelerator.device) + org_vae_device = vae.device # will be on cpu + vae.to(accelerator.device) + latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype)) + image = vae.decode(latents) + vae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + + image = image.float() + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) + decoded_np = decoded_np.astype(np.uint8) + + image = Image.fromarray(decoded_np) + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: + wandb_tracker = accelerator.get_tracker("wandb") + + import wandb + + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption + + +# region Diffusers + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils.torch_utils import randn_tensor +from diffusers.utils import BaseOutput + + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + ): + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps) + + sigmas = timesteps / self.config.num_train_timesteps + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + + timesteps = sigmas * self.config.num_train_timesteps + self.timesteps = timesteps.to(device=device) + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + sigma = self.sigmas[self.step_index] + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + + noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + # NOTE: "original_sample" should not be an expected prediction_type but is left in for + # backwards compatibility + + # if self.config.prediction_type == "vector_field": + + denoised = sample - model_output * sigma + # 2. Convert to an ODE derivative + derivative = (sample - denoised) / sigma_hat + + dt = self.sigmas[self.step_index + 1] - sigma_hat + + prev_sample = sample + derivative * dt + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + + def __len__(self): + return self.config.num_train_timesteps + + +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +# endregion + + +def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz = latents.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + t_min = args.min_timestep if args.min_timestep is not None else 0 + t_max = args.max_timestep if args.max_timestep is not None else 1000 + shift = args.training_shift + + # weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details) + u = (u * shift) / (1 + (shift - 1) * u) + + indices = (u * (t_max - t_min) + t_min).long() + timesteps = indices.to(device=device, dtype=dtype) + + # sigmas according to flowmatching + sigmas = timesteps / 1000 + sigmas = sigmas.view(-1, 1, 1, 1) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + return noisy_model_input, timesteps, sigmas diff --git a/library/sd3_utils.py b/library/sd3_utils.py new file mode 100644 index 00000000..1861dfbc --- /dev/null +++ b/library/sd3_utils.py @@ -0,0 +1,302 @@ +from dataclasses import dataclass +import math +import re +from typing import Dict, List, Optional, Union +import torch +import safetensors +from safetensors.torch import load_file +from accelerate import init_empty_weights +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPConfig, CLIPTextConfig + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import sd3_models + +# TODO move some of functions to model_util.py +from library import sdxl_model_util + +# region models + +# TODO remove dependency on flux_utils +from library.utils import load_safetensors +from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl + + +def analyze_state_dict_state(state_dict: Dict, prefix: str = ""): + logger.info(f"Analyzing state dict state...") + + # analyze configs + patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2] + depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64 + num_patches = state_dict[f"{prefix}pos_embed"].shape[1] + pos_embed_max_size = round(math.sqrt(num_patches)) + adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1] + context_shape = state_dict[f"{prefix}context_embedder.weight"].shape + qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys() else None + + # x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])) + x_block_self_attn_layers = [] + re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight") + for key in list(state_dict.keys()): + m = re_attn.search(key) + if m: + x_block_self_attn_layers.append(int(m.group(1))) + + context_embedder_in_features = context_shape[1] + context_embedder_out_features = context_shape[0] + + # only supports 3-5-large, medium or 3-medium + if qk_norm is not None: + if len(x_block_self_attn_layers) == 0: + model_type = "3-5-large" + else: + model_type = "3-5-medium" + else: + model_type = "3-medium" + + params = sd3_models.SD3Params( + patch_size=patch_size, + depth=depth, + num_patches=num_patches, + pos_embed_max_size=pos_embed_max_size, + adm_in_channels=adm_in_channels, + qk_norm=qk_norm, + x_block_self_attn_layers=x_block_self_attn_layers, + context_embedder_in_features=context_embedder_in_features, + context_embedder_out_features=context_embedder_out_features, + model_type=model_type, + ) + logger.info(f"Analyzed state dict state: {params}") + return params + + +def load_mmdit( + state_dict: Dict, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], attn_mode: str = "torch" +) -> sd3_models.MMDiT: + mmdit_sd = {} + + mmdit_prefix = "model.diffusion_model." + for k in list(state_dict.keys()): + if k.startswith(mmdit_prefix): + mmdit_sd[k[len(mmdit_prefix) :]] = state_dict.pop(k) + + # load MMDiT + logger.info("Building MMDit") + params = analyze_state_dict_state(mmdit_sd) + with init_empty_weights(): + mmdit = sd3_models.create_sd3_mmdit(params, attn_mode) + + logger.info("Loading state dict...") + info = mmdit.load_state_dict(mmdit_sd, strict=False, assign=True) + logger.info(f"Loaded MMDiT: {info}") + return mmdit + + +def load_clip_l( + clip_l_path: Optional[str], + dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[Dict] = None, +): + clip_l_sd = None + if clip_l_path is None: + if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_l: remove prefix "text_encoders.clip_l." + logger.info("clip_l is included in the checkpoint") + clip_l_sd = {} + prefix = "text_encoders.clip_l." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) + elif clip_l_path is None: + logger.info("clip_l is not included in the checkpoint and clip_l_path is not provided") + return None + + # load clip_l + logger.info("Building CLIP-L") + config = CLIPTextConfig( + vocab_size=49408, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=77, + hidden_act="quick_gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=768, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + with init_empty_weights(): + clip = CLIPTextModelWithProjection(config) + + if clip_l_sd is None: + logger.info(f"Loading state dict from {clip_l_path}") + clip_l_sd = load_safetensors(clip_l_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + + if "text_projection.weight" not in clip_l_sd: + logger.info("Adding text_projection.weight to clip_l_sd") + clip_l_sd["text_projection.weight"] = torch.eye(768, dtype=dtype, device=device) + + info = clip.load_state_dict(clip_l_sd, strict=False, assign=True) + logger.info(f"Loaded CLIP-L: {info}") + return clip + + +def load_clip_g( + clip_g_path: Optional[str], + dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[Dict] = None, +): + clip_g_sd = None + if state_dict is not None: + if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_g: remove prefix "text_encoders.clip_g." + logger.info("clip_g is included in the checkpoint") + clip_g_sd = {} + prefix = "text_encoders.clip_g." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) + elif clip_g_path is None: + logger.info("clip_g is not included in the checkpoint and clip_g_path is not provided") + return None + + # load clip_g + logger.info("Building CLIP-G") + config = CLIPTextConfig( + vocab_size=49408, + hidden_size=1280, + intermediate_size=5120, + num_hidden_layers=32, + num_attention_heads=20, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=1280, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + with init_empty_weights(): + clip = CLIPTextModelWithProjection(config) + + if clip_g_sd is None: + logger.info(f"Loading state dict from {clip_g_path}") + clip_g_sd = load_safetensors(clip_g_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = clip.load_state_dict(clip_g_sd, strict=False, assign=True) + logger.info(f"Loaded CLIP-G: {info}") + return clip + + +def load_t5xxl( + t5xxl_path: Optional[str], + dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[Dict] = None, +): + t5xxl_sd = None + if state_dict is not None: + if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: + # found t5xxl: remove prefix "text_encoders.t5xxl." + logger.info("t5xxl is included in the checkpoint") + t5xxl_sd = {} + prefix = "text_encoders.t5xxl." + for k in list(state_dict.keys()): + if k.startswith(prefix): + t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) + elif t5xxl_path is None: + logger.info("t5xxl is not included in the checkpoint and t5xxl_path is not provided") + return None + + return flux_utils_load_t5xxl(t5xxl_path, dtype, device, disable_mmap, state_dict=t5xxl_sd) + + +def load_vae( + vae_path: Optional[str], + vae_dtype: Optional[Union[str, torch.dtype]], + device: Optional[Union[str, torch.device]], + disable_mmap: bool = False, + state_dict: Optional[Dict] = None, +): + vae_sd = {} + if vae_path: + logger.info(f"Loading VAE from {vae_path}...") + vae_sd = load_safetensors(vae_path, device, disable_mmap) + else: + # remove prefix "first_stage_model." + vae_sd = {} + vae_prefix = "first_stage_model." + for k in list(state_dict.keys()): + if k.startswith(vae_prefix): + vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) + + logger.info("Building VAE") + vae = sd3_models.SDVAE(vae_dtype, device) + logger.info("Loading state dict...") + info = vae.load_state_dict(vae_sd) + logger.info(f"Loaded VAE: {info}") + vae.to(device=device, dtype=vae_dtype) # make sure it's in the right device and dtype + return vae + + +# endregion + + +class ModelSamplingDiscreteFlow: + """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models""" + + def __init__(self, shift=1.0): + self.shift = shift + timesteps = 1000 + self.sigmas = self.sigma(torch.arange(1, timesteps + 1, 1)) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return sigma * 1000 + + def sigma(self, timestep: torch.Tensor): + timestep = timestep / 1000.0 + if self.shift == 1.0: + return timestep + return self.shift * timestep / (1 + (self.shift - 1) * timestep) + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input - model_output * sigma + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + # assert max_denoise is False, "max_denoise not implemented" + # max_denoise is always True, I'm not sure why it's there + return sigma * noise + (1.0 - sigma) * latent_image diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index 03b18256..9196eb0f 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -13,12 +13,20 @@ from tqdm import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from diffusers import SchedulerMixin, StableDiffusionPipeline -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker +from diffusers.models import AutoencoderKL +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.utils import logging from PIL import Image -from library import sdxl_model_util, sdxl_train_util, train_util +from library import ( + sdxl_model_util, + sdxl_train_util, + strategy_base, + strategy_sdxl, + train_util, + sdxl_original_unet, + sdxl_original_control_net, +) try: @@ -537,7 +545,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: vae: AutoencoderKL, text_encoder: List[CLIPTextModel], tokenizer: List[CLIPTokenizer], - unet: UNet2DConditionModel, + unet: Union[sdxl_original_unet.SdxlUNet2DConditionModel, sdxl_original_control_net.SdxlControlledUNet], scheduler: SchedulerMixin, # clip_skip: int, safety_checker: StableDiffusionSafetyChecker, @@ -594,74 +602,6 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: return torch.device(module._hf_hook.execution_device) return self.device - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, - is_sdxl_text_encoder2, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `list(int)`): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - is_sdxl_text_encoder2=is_sdxl_text_encoder2, - ) - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) # ?? - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - if text_pool is not None: - text_pool = text_pool.repeat(1, num_images_per_prompt) - text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1) - - if do_classifier_free_guidance: - bs_embed, seq_len, _ = uncond_embeddings.shape - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - if uncond_pool is not None: - uncond_pool = uncond_pool.repeat(1, num_images_per_prompt) - uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1) - - return text_embeddings, text_pool, uncond_embeddings, uncond_pool - - return text_embeddings, text_pool, None, None - def check_inputs(self, prompt, height, width, strength, callback_steps): if not isinstance(prompt, str) and not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") @@ -792,7 +732,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", return_dict: bool = True, - controlnet=None, + controlnet: sdxl_original_control_net.SdxlControlNet = None, controlnet_image=None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None, @@ -896,32 +836,24 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - # 実装を簡単にするためにtokenzer/text encoderを切り替えて二回呼び出す - # To simplify the implementation, switch the tokenzer/text encoder and call it twice - text_embeddings_list = [] - text_pool = None - uncond_embeddings_list = [] - uncond_pool = None - for i in range(len(self.tokenizers)): - self.tokenizer = self.tokenizers[i] - self.text_encoder = self.text_encoders[i] + tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy: strategy_sdxl.SdxlTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() - text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, - is_sdxl_text_encoder2=i == 1, + text_input_ids, text_weights = tokenize_strategy.tokenize_with_weights(prompt) + hidden_states_1, hidden_states_2, text_pool = encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, self.text_encoders, text_input_ids, text_weights + ) + text_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1) + + if do_classifier_free_guidance: + input_ids, weights = tokenize_strategy.tokenize_with_weights(negative_prompt or "") + hidden_states_1, hidden_states_2, uncond_pool = encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, self.text_encoders, input_ids, weights ) - text_embeddings_list.append(text_embeddings) - uncond_embeddings_list.append(uncond_embeddings) - - if tp1 is not None: - text_pool = tp1 - if up1 is not None: - uncond_pool = up1 + uncond_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1) + else: + uncond_embeddings = None + uncond_pool = None unet_dtype = self.unet.dtype dtype = unet_dtype @@ -970,23 +902,23 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # create size embs and concat embeddings for SDXL - orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype) + orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(device, dtype) crop_size = torch.zeros_like(orig_size) target_size = orig_size - embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype) + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(device, dtype) # make conditionings + text_pool = text_pool.to(device, dtype) if do_classifier_free_guidance: - text_embeddings = torch.cat(text_embeddings_list, dim=2) - uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2) - text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype) + text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(device, dtype) - cond_vector = torch.cat([text_pool, embs], dim=1) - uncond_vector = torch.cat([uncond_pool, embs], dim=1) - vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype) + uncond_pool = uncond_pool.to(device, dtype) + cond_vector = torch.cat([text_pool, embs], dim=1).to(dtype) + uncond_vector = torch.cat([uncond_pool, embs], dim=1).to(dtype) + vector_embedding = torch.cat([uncond_vector, cond_vector]) else: - text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype) - vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype) + text_embedding = text_embeddings.to(device, dtype) + vector_embedding = torch.cat([text_pool, embs], dim=1) # 8. Denoising loop for i, t in enumerate(self.progress_bar(timesteps)): @@ -994,22 +926,14 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - unet_additional_args = {} - if controlnet is not None: - down_block_res_samples, mid_block_res_sample = controlnet( - latent_model_input, - t, - encoder_hidden_states=text_embeddings, - controlnet_cond=controlnet_image, - conditioning_scale=1.0, - guess_mode=False, - return_dict=False, - ) - unet_additional_args["down_block_additional_residuals"] = down_block_res_samples - unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample + # FIXME SD1 ControlNet is not working # predict the noise residual - noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding) + if controlnet is not None: + input_resi_add, mid_add = controlnet(latent_model_input, t, text_embedding, vector_embedding, controlnet_image) + noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding, input_resi_add, mid_add) + else: + noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding) noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training # perform guidance diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index f03f1bae..0466c1fa 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -1,4 +1,5 @@ import torch +import safetensors from accelerate import init_empty_weights from accelerate.utils.modeling import set_module_tensor_to_device from safetensors.torch import load_file, save_file @@ -7,9 +8,11 @@ from typing import List from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel from library import model_util from library import sdxl_original_unet -from .utils import setup_logging +from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) VAE_SCALE_FACTOR = 0.13025 @@ -163,17 +166,20 @@ def _load_state_dict_on_device(model, state_dict, device, dtype=None): raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))) -def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None): +def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False): # model_version is reserved for future use # dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching # Load the state dict if model_util.is_safetensors(ckpt_path): checkpoint = None - try: - state_dict = load_file(ckpt_path, device=map_location) - except: - state_dict = load_file(ckpt_path) # prevent device invalid Error + if disable_mmap: + state_dict = safetensors.torch.load(open(ckpt_path, "rb").read()) + else: + try: + state_dict = load_file(ckpt_path, device=map_location) + except: + state_dict = load_file(ckpt_path) # prevent device invalid Error epoch = None global_step = None else: diff --git a/library/sdxl_original_control_net.py b/library/sdxl_original_control_net.py new file mode 100644 index 00000000..3af45f4d --- /dev/null +++ b/library/sdxl_original_control_net.py @@ -0,0 +1,272 @@ +# some parts are modified from Diffusers library (Apache License 2.0) + +import math +from types import SimpleNamespace +from typing import Any, Optional +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from einops import rearrange +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import sdxl_original_unet +from library.sdxl_model_util import convert_sdxl_unet_state_dict_to_diffusers, convert_diffusers_unet_state_dict_to_sdxl + + +class ControlNetConditioningEmbedding(nn.Module): + def __init__(self): + super().__init__() + + dims = [16, 32, 96, 256] + + self.conv_in = nn.Conv2d(3, dims[0], kernel_size=3, padding=1) + self.blocks = nn.ModuleList([]) + + for i in range(len(dims) - 1): + channel_in = dims[i] + channel_out = dims[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = nn.Conv2d(dims[-1], 320, kernel_size=3, padding=1) + nn.init.zeros_(self.conv_out.weight) # zero module weight + nn.init.zeros_(self.conv_out.bias) # zero module bias + + def forward(self, x): + x = self.conv_in(x) + x = F.silu(x) + for block in self.blocks: + x = block(x) + x = F.silu(x) + x = self.conv_out(x) + return x + + +class SdxlControlNet(sdxl_original_unet.SdxlUNet2DConditionModel): + def __init__(self, multiplier: Optional[float] = None, **kwargs): + super().__init__(**kwargs) + self.multiplier = multiplier + + # remove unet layers + self.output_blocks = nn.ModuleList([]) + del self.out + + self.controlnet_cond_embedding = ControlNetConditioningEmbedding() + + dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280] + self.controlnet_down_blocks = nn.ModuleList([]) + for dim in dims: + self.controlnet_down_blocks.append(nn.Conv2d(dim, dim, kernel_size=1)) + nn.init.zeros_(self.controlnet_down_blocks[-1].weight) # zero module weight + nn.init.zeros_(self.controlnet_down_blocks[-1].bias) # zero module bias + + self.controlnet_mid_block = nn.Conv2d(1280, 1280, kernel_size=1) + nn.init.zeros_(self.controlnet_mid_block.weight) # zero module weight + nn.init.zeros_(self.controlnet_mid_block.bias) # zero module bias + + def init_from_unet(self, unet: sdxl_original_unet.SdxlUNet2DConditionModel): + unet_sd = unet.state_dict() + unet_sd = {k: v for k, v in unet_sd.items() if not k.startswith("out")} + sd = super().state_dict() + sd.update(unet_sd) + info = super().load_state_dict(sd, strict=True, assign=True) + return info + + def load_state_dict(self, state_dict: dict, strict: bool = True, assign: bool = True) -> Any: + # convert state_dict to SAI format + unet_sd = {} + for k in list(state_dict.keys()): + if not k.startswith("controlnet_"): + unet_sd[k] = state_dict.pop(k) + unet_sd = convert_diffusers_unet_state_dict_to_sdxl(unet_sd) + state_dict.update(unet_sd) + super().load_state_dict(state_dict, strict=strict, assign=assign) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + # convert state_dict to Diffusers format + state_dict = super().state_dict(destination, prefix, keep_vars) + control_net_sd = {} + for k in list(state_dict.keys()): + if k.startswith("controlnet_"): + control_net_sd[k] = state_dict.pop(k) + state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict) + state_dict.update(control_net_sd) + return state_dict + + def forward( + self, + x: torch.Tensor, + timesteps: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + cond_image: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) + t_emb = t_emb.to(x.dtype) + emb = self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + emb = emb + self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + if isinstance(layer, sdxl_original_unet.ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, sdxl_original_unet.Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + h = x + multiplier = self.multiplier if self.multiplier is not None else 1.0 + hs = [] + for i, module in enumerate(self.input_blocks): + h = call_module(module, h, emb, context) + if i == 0: + h = self.controlnet_cond_embedding(cond_image) + h + hs.append(self.controlnet_down_blocks[i](h) * multiplier) + + h = call_module(self.middle_block, h, emb, context) + h = self.controlnet_mid_block(h) * multiplier + + return hs, h + + +class SdxlControlledUNet(sdxl_original_unet.SdxlUNet2DConditionModel): + """ + This class is for training purpose only. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs): + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + hs = [] + t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) + t_emb = t_emb.to(x.dtype) + emb = self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + emb = emb + self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + if isinstance(layer, sdxl_original_unet.ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, sdxl_original_unet.Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + h = x + for module in self.input_blocks: + h = call_module(module, h, emb, context) + hs.append(h) + + h = call_module(self.middle_block, h, emb, context) + h = h + mid_add + + for module in self.output_blocks: + resi = hs.pop() + input_resi_add.pop() + h = torch.cat([h, resi], dim=1) + h = call_module(module, h, emb, context) + + h = h.type(x.dtype) + h = call_module(self.out, h, emb, context) + + return h + + +if __name__ == "__main__": + import time + + logger.info("create unet") + unet = SdxlControlledUNet() + unet.to("cuda", torch.bfloat16) + unet.set_use_sdpa(True) + unet.set_gradient_checkpointing(True) + unet.train() + + logger.info("create control_net") + control_net = SdxlControlNet() + control_net.to("cuda") + control_net.set_use_sdpa(True) + control_net.set_gradient_checkpointing(True) + control_net.train() + + logger.info("Initialize control_net from unet") + control_net.init_from_unet(unet) + + unet.requires_grad_(False) + control_net.requires_grad_(True) + + # 使用メモリ量確認用の疑似学習ループ + logger.info("preparing optimizer") + + # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working + + import bitsandbytes + + optimizer = bitsandbytes.adam.Adam8bit(control_net.parameters(), lr=1e-3) # not working + # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 + # optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 + + # import transformers + # optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2 + + scaler = torch.cuda.amp.GradScaler(enabled=True) + + logger.info("start training") + steps = 10 + batch_size = 1 + + for step in range(steps): + logger.info(f"step {step}") + if step == 1: + time_start = time.perf_counter() + + x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024 + t = torch.randint(low=0, high=1000, size=(batch_size,), device="cuda") + txt = torch.randn(batch_size, 77, 2048).cuda() + vector = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda() + cond_img = torch.rand(batch_size, 3, 1024, 1024).cuda() + + with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16): + input_resi_add, mid_add = control_net(x, t, txt, vector, cond_img) + output = unet(x, t, txt, vector, input_resi_add, mid_add) + target = torch.randn_like(output) + loss = torch.nn.functional.mse_loss(output, target) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad(set_to_none=True) + + time_end = time.perf_counter() + logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") + + logger.info("finish training") + sd = control_net.state_dict() + + from safetensors.torch import save_file + + save_file(sd, r"E:\Work\SD\Tmp\sdxl\ctrl\control_net.safetensors") diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 17c345a8..0aa07d0d 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -30,7 +30,7 @@ import torch.utils.checkpoint from torch import nn from torch.nn import functional as F from einops import rearrange -from .utils import setup_logging +from library.utils import setup_logging setup_logging() import logging @@ -1156,9 +1156,9 @@ class InferSdxlUNet2DConditionModel: self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 self.ds_ratio = ds_ratio - def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs): r""" - current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink. + current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink and ControlNet. """ _self = self.delegate @@ -1209,6 +1209,8 @@ class InferSdxlUNet2DConditionModel: hs.append(h) h = call_module(_self.middle_block, h, emb, context) + if mid_add is not None: + h = h + mid_add for module in _self.output_blocks: # Deep Shrink @@ -1217,7 +1219,11 @@ class InferSdxlUNet2DConditionModel: # print("upsample", h.shape, hs[-1].shape) h = resize_like(h, hs[-1]) - h = torch.cat([h, hs.pop()], dim=1) + resi = hs.pop() + if input_resi_add is not None: + resi = resi + input_resi_add.pop() + + h = torch.cat([h, resi], dim=1) h = call_module(module, h, emb, context) # Deep Shrink: in case of depth 0 diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index a29013e3..dc3887c3 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -5,16 +5,18 @@ from typing import Optional import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from accelerate import init_empty_weights from tqdm import tqdm from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet -from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline from .utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) TOKENIZER1_PATH = "openai/clip-vit-large-patch14" @@ -44,6 +46,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): weight_dtype, accelerator.device if args.lowram else "cpu", model_dtype, + args.disable_mmap_load_safetensors, ) # work on low-ram device @@ -60,7 +63,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): def _load_target_model( - name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None + name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False ): # model_dtype only work with full fp16/bf16 name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path @@ -75,7 +78,7 @@ def _load_target_model( unet, logit_scale, ckpt_info, - ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype) + ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap) else: # Diffusers model is loaded to CPU from diffusers import StableDiffusionXLPipeline @@ -323,7 +326,7 @@ def save_sd_model_on_epoch_end_or_stepwise( ) -def add_sdxl_training_arguments(parser: argparse.ArgumentParser): +def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True): parser.add_argument( "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" ) @@ -332,6 +335,11 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser): action="store_true", help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", ) + parser.add_argument( + "--disable_mmap_load_safetensors", + action="store_true", + help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", + ) def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): @@ -355,9 +363,9 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin # ) # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") - assert ( - not hasattr(args, "weighted_captions") or not args.weighted_captions - ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" + # assert ( + # not hasattr(args, "weighted_captions") or not args.weighted_captions + # ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" if supportTextEncoderCaching: if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: @@ -369,4 +377,6 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin def sample_images(*args, **kwargs): + from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline + return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) diff --git a/library/strategy_base.py b/library/strategy_base.py new file mode 100644 index 00000000..358e42f1 --- /dev/null +++ b/library/strategy_base.py @@ -0,0 +1,570 @@ +# base class for platform strategies. this file defines the interface for strategies + +import os +import re +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection + + +# TODO remove circular import by moving ImageInfo to a separate file +# from library.train_util import ImageInfo + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class TokenizeStrategy: + _strategy = None # strategy instance: actual strategy class + + _re_attention = re.compile( + r"""\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, + ) + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TokenizeStrategy"]: + return cls._strategy + + def _load_tokenizer( + self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None + ) -> Any: + tokenizer = None + if tokenizer_cache_dir: + local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_")) + if os.path.exists(local_tokenizer_path): + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") + tokenizer = model_class.from_pretrained(local_tokenizer_path) # same for v1 and v2 + + if tokenizer is None: + tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder) + + if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") + tokenizer.save_pretrained(local_tokenizer_path) + + return tokenizer + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + raise NotImplementedError + + def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + returns: [tokens1, tokens2, ...], [weights1, weights2, ...] + """ + raise NotImplementedError + + def _get_weighted_input_ids( + self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + max_length includes starting and ending tokens. + """ + + def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in TokenizeStrategy._re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + def get_prompts_with_weights(text: str, max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. max_length does not include starting and ending token. + + No padding, starting or ending token is included. + """ + truncated = False + + texts_and_weights = parse_prompt_attention(text) + tokens = [] + weights = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + tokens += token + # copy the weight by length of token + weights += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(tokens) > max_length: + truncated = True + break + # truncate + if len(tokens) > max_length: + truncated = True + tokens = tokens[:max_length] + weights = weights[:max_length] + if truncated: + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + tokens = [bos] + tokens + [eos] + [pad] * (max_length - 2 - len(tokens)) + weights = [1.0] + weights + [1.0] * (max_length - 1 - len(weights)) + return tokens, weights + + if max_length is None: + max_length = tokenizer.model_max_length + + tokens, weights = get_prompts_with_weights(text, max_length - 2) + tokens, weights = pad_tokens_and_weights( + tokens, weights, max_length, tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id + ) + return torch.tensor(tokens).unsqueeze(0), torch.tensor(weights).unsqueeze(0) + + def _get_input_ids( + self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None, weighted: bool = False + ) -> torch.Tensor: + """ + for SD1.5/2.0/SDXL + TODO support batch input + """ + if max_length is None: + max_length = tokenizer.model_max_length - 2 + + if weighted: + input_ids, weights = self._get_weighted_input_ids(tokenizer, text, max_length) + else: + input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids + + if max_length > tokenizer.model_max_length: + input_ids = input_ids.squeeze(0) + iids_list = [] + if tokenizer.pad_token_id == tokenizer.eos_token_id: + # v1 + # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する + # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): # (1, 152, 75) + ids_chunk = ( + input_ids[0].unsqueeze(0), + input_ids[i : i + tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 or SDXL + # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): + ids_chunk = ( + input_ids[0].unsqueeze(0), # BOS + input_ids[i : i + tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) + + # 末尾が または の場合は、何もしなくてよい + # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) + if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id: + ids_chunk[-1] = tokenizer.eos_token_id + # 先頭が ... の場合は ... に変える + if ids_chunk[1] == tokenizer.pad_token_id: + ids_chunk[1] = tokenizer.eos_token_id + + iids_list.append(ids_chunk) + + input_ids = torch.stack(iids_list) # 3,77 + + if weighted: + weights = weights.squeeze(0) + new_weights = torch.ones(input_ids.shape) + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): + b = i // (tokenizer.model_max_length - 2) + new_weights[b, 1 : 1 + tokenizer.model_max_length - 2] = weights[i : i + tokenizer.model_max_length - 2] + weights = new_weights + + if weighted: + return input_ids, weights + return input_ids + + +class TextEncodingStrategy: + _strategy = None # strategy instance: actual strategy class + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TextEncodingStrategy"]: + return cls._strategy + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Encode tokens into embeddings and outputs. + :param tokens: list of token tensors for each TextModel + :return: list of output embeddings for each architecture + """ + raise NotImplementedError + + def encode_tokens_with_weights( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Encode tokens into embeddings and outputs. + :param tokens: list of token tensors for each TextModel + :param weights: list of weight tensors for each TextModel + :return: list of output embeddings for each architecture + """ + raise NotImplementedError + + +class TextEncoderOutputsCachingStrategy: + _strategy = None # strategy instance: actual strategy class + + def __init__( + self, + cache_to_disk: bool, + batch_size: Optional[int], + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + is_weighted: bool = False, + ) -> None: + self._cache_to_disk = cache_to_disk + self._batch_size = batch_size + self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + self._is_partial = is_partial + self._is_weighted = is_weighted + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]: + return cls._strategy + + @property + def cache_to_disk(self): + return self._cache_to_disk + + @property + def batch_size(self): + return self._batch_size + + @property + def is_partial(self): + return self._is_partial + + @property + def is_weighted(self): + return self._is_weighted + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + raise NotImplementedError + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + raise NotImplementedError + + def is_disk_cached_outputs_expected(self, npz_path: str) -> bool: + raise NotImplementedError + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List + ): + raise NotImplementedError + + +class LatentsCachingStrategy: + # TODO commonize utillity functions to this class, such as npz handling etc. + + _strategy = None # strategy instance: actual strategy class + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + self._cache_to_disk = cache_to_disk + self._batch_size = batch_size + self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["LatentsCachingStrategy"]: + return cls._strategy + + @property + def cache_to_disk(self): + return self._cache_to_disk + + @property + def batch_size(self): + return self._batch_size + + @property + def cache_suffix(self): + raise NotImplementedError + + def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]: + w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + raise NotImplementedError + + def is_disk_cached_latents_expected( + self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + ) -> bool: + raise NotImplementedError + + def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + raise NotImplementedError + + def _default_is_disk_cached_latents_expected( + self, + latents_stride: int, + bucket_reso: Tuple[int, int], + npz_path: str, + flip_aug: bool, + alpha_mask: bool, + multi_resolution: bool = False, + ): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) + + # e.g. "_32x64", HxW + key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else "" + + try: + npz = np.load(npz_path) + if "latents" + key_reso_suffix not in npz: + return False + if flip_aug and "latents_flipped" + key_reso_suffix not in npz: + return False + if alpha_mask and "alpha_mask" + key_reso_suffix not in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + # TODO remove circular dependency for ImageInfo + def _default_cache_batch_latents( + self, + encode_by_vae, + vae_device, + vae_dtype, + image_infos: List, + flip_aug: bool, + alpha_mask: bool, + random_crop: bool, + multi_resolution: bool = False, + ): + """ + Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common. + """ + from library import train_util # import here to avoid circular import + + img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching( + image_infos, alpha_mask, random_crop + ) + img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype) + + with torch.no_grad(): + latents_tensors = encode_by_vae(img_tensor).to("cpu") + if flip_aug: + img_tensor = torch.flip(img_tensor, dims=[3]) + with torch.no_grad(): + flipped_latents = encode_by_vae(img_tensor).to("cpu") + else: + flipped_latents = [None] * len(latents_tensors) + + # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks): + for i in range(len(image_infos)): + info = image_infos[i] + latents = latents_tensors[i] + flipped_latent = flipped_latents[i] + alpha_mask = alpha_masks[i] + original_size = original_sizes[i] + crop_ltrb = crop_ltrbs[i] + + latents_size = latents.shape[1:3] # H, W + key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW + + if self.cache_to_disk: + self.save_latents_to_disk( + info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix + ) + else: + info.latents_original_size = original_size + info.latents_crop_ltrb = crop_ltrb + info.latents = latents + if flip_aug: + info.latents_flipped = flipped_latent + info.alpha_mask = alpha_mask + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + """ + for SD/SDXL + """ + return self._default_load_latents_from_disk(None, npz_path, bucket_reso) + + def _default_load_latents_from_disk( + self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + if latents_stride is None: + key_reso_suffix = "" + else: + latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) + key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW + + npz = np.load(npz_path) + if "latents" + key_reso_suffix not in npz: + raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}") + + latents = npz["latents" + key_reso_suffix] + original_size = npz["original_size" + key_reso_suffix].tolist() + crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist() + flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None + alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask + + def save_latents_to_disk( + self, + npz_path, + latents_tensor, + original_size, + crop_ltrb, + flipped_latents_tensor=None, + alpha_mask=None, + key_reso_suffix="", + ): + kwargs = {} + + if os.path.exists(npz_path): + # load existing npz and update it + npz = np.load(npz_path) + for key in npz.files: + kwargs[key] = npz[key] + + kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy() + kwargs["original_size" + key_reso_suffix] = np.array(original_size) + kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb) + if flipped_latents_tensor is not None: + kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy() + if alpha_mask is not None: + kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy() + np.savez(npz_path, **kwargs) diff --git a/library/strategy_flux.py b/library/strategy_flux.py new file mode 100644 index 00000000..5e65927f --- /dev/null +++ b/library/strategy_flux.py @@ -0,0 +1,271 @@ +import os +import glob +from typing import Any, List, Optional, Tuple, Union +import torch +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast + +from library import flux_utils, train_util +from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14" +T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl" + + +class FluxTokenizeStrategy(TokenizeStrategy): + def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None: + self.t5xxl_max_length = t5xxl_max_length + self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + + l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt") + + t5_attn_mask = t5_tokens["attention_mask"] + l_tokens = l_tokens["input_ids"] + t5_tokens = t5_tokens["input_ids"] + + return [l_tokens, t5_tokens, t5_attn_mask] + + +class FluxTextEncodingStrategy(TextEncodingStrategy): + def __init__(self, apply_t5_attn_mask: Optional[bool] = None) -> None: + """ + Args: + apply_t5_attn_mask: Default value for apply_t5_attn_mask. + """ + self.apply_t5_attn_mask = apply_t5_attn_mask + + def encode_tokens( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: List[torch.Tensor], + apply_t5_attn_mask: Optional[bool] = None, + ) -> List[torch.Tensor]: + # supports single model inference + + if apply_t5_attn_mask is None: + apply_t5_attn_mask = self.apply_t5_attn_mask + + clip_l, t5xxl = models if len(models) == 2 else (models[0], None) + l_tokens, t5_tokens = tokens[:2] + t5_attn_mask = tokens[2] if len(tokens) > 2 else None + + # clip_l is None when using T5 only + if clip_l is not None and l_tokens is not None: + l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"] + else: + l_pooled = None + + # t5xxl is None when using CLIP only + if t5xxl is not None and t5_tokens is not None: + # t5_out is [b, max length, 4096] + attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device) + t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True) + # if zero_pad_t5_output: + # t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) + txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device) + else: + t5_out = None + txt_ids = None + t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one + + return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer + + +class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz" + + def __init__( + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + apply_t5_attn_mask: bool = False, + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + self.apply_t5_attn_mask = apply_t5_attn_mask + + self.warn_fp8_weights = False + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, npz_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(npz_path) + if "l_pooled" not in npz: + return False + if "t5_out" not in npz: + return False + if "txt_ids" not in npz: + return False + if "t5_attn_mask" not in npz: + return False + if "apply_t5_attn_mask" not in npz: + return False + npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] + if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + l_pooled = data["l_pooled"] + t5_out = data["t5_out"] + txt_ids = data["txt_ids"] + t5_attn_mask = data["t5_attn_mask"] + # apply_t5_attn_mask should be same as self.apply_t5_attn_mask + return [l_pooled, t5_out, txt_ids, t5_attn_mask] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + if not self.warn_fp8_weights: + if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn: + logger.warning( + "T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs." + " / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。" + ) + self.warn_fp8_weights = True + + flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy + captions = [info.caption for info in infos] + + tokens_and_masks = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + # attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True + l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks) + + if l_pooled.dtype == torch.bfloat16: + l_pooled = l_pooled.float() + if t5_out.dtype == torch.bfloat16: + t5_out = t5_out.float() + if txt_ids.dtype == torch.bfloat16: + txt_ids = txt_ids.float() + + l_pooled = l_pooled.cpu().numpy() + t5_out = t5_out.cpu().numpy() + txt_ids = txt_ids.cpu().numpy() + t5_attn_mask = tokens_and_masks[2].cpu().numpy() + + for i, info in enumerate(infos): + l_pooled_i = l_pooled[i] + t5_out_i = t5_out[i] + txt_ids_i = txt_ids[i] + t5_attn_mask_i = t5_attn_mask[i] + apply_t5_attn_mask_i = self.apply_t5_attn_mask + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + l_pooled=l_pooled_i, + t5_out=t5_out_i, + txt_ids=txt_ids_i, + t5_attn_mask=t5_attn_mask_i, + apply_t5_attn_mask=apply_t5_attn_mask_i, + ) + else: + # it's fine that attn mask is not None. it's overwritten before calling the model if necessary + info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i) + + +class FluxLatentsCachingStrategy(LatentsCachingStrategy): + FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz" + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + @property + def cache_suffix(self) -> str: + return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents( + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True + ) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) + + +if __name__ == "__main__": + # test code for FluxTokenizeStrategy + # tokenizer = sd3_models.SD3Tokenizer() + strategy = FluxTokenizeStrategy(256) + text = "hello world" + + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + # print(l_tokens.shape) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + texts = ["hello world", "the quick brown fox jumps over the lazy dog"] + l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens_2 = strategy.t5xxl( + texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + print(l_tokens_2) + print(g_tokens_2) + print(t5_tokens_2) + + # compare + print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0])) + print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0])) + print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0])) + + text = ",".join(["hello world! this is long text"] * 50) + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + print(f"model max length l: {strategy.clip_l.model_max_length}") + print(f"model max length g: {strategy.clip_g.model_max_length}") + print(f"model max length t5: {strategy.t5xxl.model_max_length}") diff --git a/library/strategy_sd.py b/library/strategy_sd.py new file mode 100644 index 00000000..d0a3a68b --- /dev/null +++ b/library/strategy_sd.py @@ -0,0 +1,171 @@ +import glob +import os +from typing import Any, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTokenizer +from library import train_util +from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +TOKENIZER_ID = "openai/clip-vit-large-patch14" +V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ + + +class SdTokenizeStrategy(TokenizeStrategy): + def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None: + """ + max_length does not include and (None, 75, 150, 225) + """ + logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer") + if v2: + self.tokenizer = self._load_tokenizer( + CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir + ) + else: + self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + + if max_length is None: + self.max_length = self.tokenizer.model_max_length + else: + self.max_length = max_length + 2 + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] + + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: + text = [text] if isinstance(text, str) else text + tokens_list = [] + weights_list = [] + for t in text: + tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True) + tokens_list.append(tokens) + weights_list.append(weights) + return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)] + + +class SdTextEncodingStrategy(TextEncodingStrategy): + def __init__(self, clip_skip: Optional[int] = None) -> None: + self.clip_skip = clip_skip + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + text_encoder = models[0] + tokens = tokens[0] + sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy + + # tokens: b,n,77 + b_size = tokens.size()[0] + max_token_length = tokens.size()[1] * tokens.size()[2] + model_max_length = sd_tokenize_strategy.tokenizer.model_max_length + tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77 + + tokens = tokens.to(text_encoder.device) + + if self.clip_skip is None: + encoder_hidden_states = text_encoder(tokens)[0] + else: + enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip] + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + + # bs*3, 77, 768 or 1024 + encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) + + if max_token_length != model_max_length: + v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id + if not v1: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, model_max_length): + chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token: + # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: ... の三連を ... へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, model_max_length): + states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) + + return [encoder_hidden_states] + + def encode_tokens_with_weights( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens_list: List[torch.Tensor], + weights_list: List[torch.Tensor], + ) -> List[torch.Tensor]: + encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0] + + weights = weights_list[0].to(encoder_hidden_states.device) + + # apply weights + if weights.shape[1] == 1: # no max_token_length + # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) + encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2) + else: + # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) + for i in range(weights.shape[1]): + encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[ + :, i, 1:-1 + ].unsqueeze(-1) + + return [encoder_hidden_states] + + +class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): + # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix. + # and we keep the old npz for the backward compatibility. + + SD_OLD_LATENTS_NPZ_SUFFIX = ".npz" + SD_LATENTS_NPZ_SUFFIX = "_sd.npz" + SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz" + + def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + self.sd = sd + self.suffix = ( + SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX + ) + + @property + def cache_suffix(self) -> str: + return self.suffix + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + # support old .npz + old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX + if os.path.exists(old_npz_file): + return old_npz_file + return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample() + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py new file mode 100644 index 00000000..1d55fe21 --- /dev/null +++ b/library/strategy_sd3.py @@ -0,0 +1,420 @@ +import os +import glob +import random +from typing import Any, List, Optional, Tuple, Union +import torch +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel + +from library import sd3_utils, train_util +from library import sd3_models +from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14" +CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" +T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl" + + +class Sd3TokenizeStrategy(TokenizeStrategy): + def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None: + self.t5xxl_max_length = t5xxl_max_length + self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.clip_g = self._load_tokenizer(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.clip_g.pad_token_id = 0 # use 0 as pad token for clip_g + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + + l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt") + + l_attn_mask = l_tokens["attention_mask"] + g_attn_mask = g_tokens["attention_mask"] + t5_attn_mask = t5_tokens["attention_mask"] + l_tokens = l_tokens["input_ids"] + g_tokens = g_tokens["input_ids"] + t5_tokens = t5_tokens["input_ids"] + + return [l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask] + + +class Sd3TextEncodingStrategy(TextEncodingStrategy): + def __init__( + self, + apply_lg_attn_mask: Optional[bool] = None, + apply_t5_attn_mask: Optional[bool] = None, + l_dropout_rate: float = 0.0, + g_dropout_rate: float = 0.0, + t5_dropout_rate: float = 0.0, + ) -> None: + """ + Args: + apply_t5_attn_mask: Default value for apply_t5_attn_mask. + """ + self.apply_lg_attn_mask = apply_lg_attn_mask + self.apply_t5_attn_mask = apply_t5_attn_mask + self.l_dropout_rate = l_dropout_rate + self.g_dropout_rate = g_dropout_rate + self.t5_dropout_rate = t5_dropout_rate + + def encode_tokens( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: List[torch.Tensor], + apply_lg_attn_mask: Optional[bool] = False, + apply_t5_attn_mask: Optional[bool] = False, + enable_dropout: bool = True, + ) -> List[torch.Tensor]: + """ + returned embeddings are not masked + """ + clip_l, clip_g, t5xxl = models + clip_l: Optional[CLIPTextModel] + clip_g: Optional[CLIPTextModelWithProjection] + t5xxl: Optional[T5EncoderModel] + + if apply_lg_attn_mask is None: + apply_lg_attn_mask = self.apply_lg_attn_mask + if apply_t5_attn_mask is None: + apply_t5_attn_mask = self.apply_t5_attn_mask + + l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens + + # dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings + + if l_tokens is None or clip_l is None: + assert g_tokens is None, "g_tokens must be None if l_tokens is None" + lg_out = None + lg_pooled = None + l_attn_mask = None + g_attn_mask = None + else: + assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" + + # drop some members of the batch: we do not call clip_l and clip_g for dropped members + batch_size, l_seq_len = l_tokens.shape + g_seq_len = g_tokens.shape[1] + + non_drop_l_indices = [] + non_drop_g_indices = [] + for i in range(l_tokens.shape[0]): + drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate) + drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate) + if not drop_l: + non_drop_l_indices.append(i) + if not drop_g: + non_drop_g_indices.append(i) + + # filter out dropped members + if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size: + l_tokens = l_tokens[non_drop_l_indices] + l_attn_mask = l_attn_mask[non_drop_l_indices] + if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size: + g_tokens = g_tokens[non_drop_g_indices] + g_attn_mask = g_attn_mask[non_drop_g_indices] + + # call clip_l for non-dropped members + if len(non_drop_l_indices) > 0: + nd_l_attn_mask = l_attn_mask.to(clip_l.device) + prompt_embeds = clip_l( + l_tokens.to(clip_l.device), nd_l_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True + ) + nd_l_pooled = prompt_embeds[0] + nd_l_out = prompt_embeds.hidden_states[-2] + if len(non_drop_g_indices) > 0: + nd_g_attn_mask = g_attn_mask.to(clip_g.device) + prompt_embeds = clip_g( + g_tokens.to(clip_g.device), nd_g_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True + ) + nd_g_pooled = prompt_embeds[0] + nd_g_out = prompt_embeds.hidden_states[-2] + + # fill in the dropped members + if len(non_drop_l_indices) == batch_size: + l_pooled = nd_l_pooled + l_out = nd_l_out + else: + # model output is always float32 because of the models are wrapped with Accelerator + l_pooled = torch.zeros((batch_size, 768), device=clip_l.device, dtype=torch.float32) + l_out = torch.zeros((batch_size, l_seq_len, 768), device=clip_l.device, dtype=torch.float32) + l_attn_mask = torch.zeros((batch_size, l_seq_len), device=clip_l.device, dtype=l_attn_mask.dtype) + if len(non_drop_l_indices) > 0: + l_pooled[non_drop_l_indices] = nd_l_pooled + l_out[non_drop_l_indices] = nd_l_out + l_attn_mask[non_drop_l_indices] = nd_l_attn_mask + + if len(non_drop_g_indices) == batch_size: + g_pooled = nd_g_pooled + g_out = nd_g_out + else: + g_pooled = torch.zeros((batch_size, 1280), device=clip_g.device, dtype=torch.float32) + g_out = torch.zeros((batch_size, g_seq_len, 1280), device=clip_g.device, dtype=torch.float32) + g_attn_mask = torch.zeros((batch_size, g_seq_len), device=clip_g.device, dtype=g_attn_mask.dtype) + if len(non_drop_g_indices) > 0: + g_pooled[non_drop_g_indices] = nd_g_pooled + g_out[non_drop_g_indices] = nd_g_out + g_attn_mask[non_drop_g_indices] = nd_g_attn_mask + + lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) + lg_out = torch.cat([l_out, g_out], dim=-1) + + if t5xxl is None or t5_tokens is None: + t5_out = None + t5_attn_mask = None + else: + # drop some members of the batch: we do not call t5xxl for dropped members + batch_size, t5_seq_len = t5_tokens.shape + non_drop_t5_indices = [] + for i in range(t5_tokens.shape[0]): + drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate) + if not drop_t5: + non_drop_t5_indices.append(i) + + # filter out dropped members + if len(non_drop_t5_indices) > 0 and len(non_drop_t5_indices) < batch_size: + t5_tokens = t5_tokens[non_drop_t5_indices] + t5_attn_mask = t5_attn_mask[non_drop_t5_indices] + + # call t5xxl for non-dropped members + if len(non_drop_t5_indices) > 0: + nd_t5_attn_mask = t5_attn_mask.to(t5xxl.device) + nd_t5_out, _ = t5xxl( + t5_tokens.to(t5xxl.device), + nd_t5_attn_mask if apply_t5_attn_mask else None, + return_dict=False, + output_hidden_states=True, + ) + + # fill in the dropped members + if len(non_drop_t5_indices) == batch_size: + t5_out = nd_t5_out + else: + t5_out = torch.zeros((batch_size, t5_seq_len, 4096), device=t5xxl.device, dtype=torch.float32) + t5_attn_mask = torch.zeros((batch_size, t5_seq_len), device=t5xxl.device, dtype=t5_attn_mask.dtype) + if len(non_drop_t5_indices) > 0: + t5_out[non_drop_t5_indices] = nd_t5_out + t5_attn_mask[non_drop_t5_indices] = nd_t5_attn_mask + + # masks are used for attention masking in transformer + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] + + def drop_cached_text_encoder_outputs( + self, + lg_out: torch.Tensor, + t5_out: torch.Tensor, + lg_pooled: torch.Tensor, + l_attn_mask: torch.Tensor, + g_attn_mask: torch.Tensor, + t5_attn_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # dropout: if enable_dropout is True, dropout is not applied. dropout means zeroing out embeddings + if lg_out is not None: + for i in range(lg_out.shape[0]): + drop_l = self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate + if drop_l: + lg_out[i, :, :768] = torch.zeros_like(lg_out[i, :, :768]) + lg_pooled[i, :768] = torch.zeros_like(lg_pooled[i, :768]) + if l_attn_mask is not None: + l_attn_mask[i] = torch.zeros_like(l_attn_mask[i]) + drop_g = self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate + if drop_g: + lg_out[i, :, 768:] = torch.zeros_like(lg_out[i, :, 768:]) + lg_pooled[i, 768:] = torch.zeros_like(lg_pooled[i, 768:]) + if g_attn_mask is not None: + g_attn_mask[i] = torch.zeros_like(g_attn_mask[i]) + + if t5_out is not None: + for i in range(t5_out.shape[0]): + drop_t5 = self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate + if drop_t5: + t5_out[i] = torch.zeros_like(t5_out[i]) + if t5_attn_mask is not None: + t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i]) + + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] + + def concat_encodings( + self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + if t5_out is None: + t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype) + return torch.cat([lg_out, t5_out], dim=-2), lg_pooled + + +class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz" + + def __init__( + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + apply_lg_attn_mask: bool = False, + apply_t5_attn_mask: bool = False, + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + self.apply_lg_attn_mask = apply_lg_attn_mask + self.apply_t5_attn_mask = apply_t5_attn_mask + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, npz_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(npz_path) + if "lg_out" not in npz: + return False + if "lg_pooled" not in npz: + return False + if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used + return False + if "apply_lg_attn_mask" not in npz: + return False + if "t5_out" not in npz: + return False + if "t5_attn_mask" not in npz: + return False + npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"] + if npz_apply_lg_attn_mask != self.apply_lg_attn_mask: + return False + if "apply_t5_attn_mask" not in npz: + return False + npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] + if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + lg_out = data["lg_out"] + lg_pooled = data["lg_pooled"] + t5_out = data["t5_out"] + + l_attn_mask = data["clip_l_attn_mask"] + g_attn_mask = data["clip_g_attn_mask"] + t5_attn_mask = data["t5_attn_mask"] + + # apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy + captions = [info.caption for info in infos] + + tokens_and_masks = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + # always disable dropout during caching + lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens( + tokenize_strategy, + models, + tokens_and_masks, + apply_lg_attn_mask=self.apply_lg_attn_mask, + apply_t5_attn_mask=self.apply_t5_attn_mask, + enable_dropout=False, + ) + + if lg_out.dtype == torch.bfloat16: + lg_out = lg_out.float() + if lg_pooled.dtype == torch.bfloat16: + lg_pooled = lg_pooled.float() + if t5_out.dtype == torch.bfloat16: + t5_out = t5_out.float() + + lg_out = lg_out.cpu().numpy() + lg_pooled = lg_pooled.cpu().numpy() + t5_out = t5_out.cpu().numpy() + + l_attn_mask = tokens_and_masks[3].cpu().numpy() + g_attn_mask = tokens_and_masks[4].cpu().numpy() + t5_attn_mask = tokens_and_masks[5].cpu().numpy() + + for i, info in enumerate(infos): + lg_out_i = lg_out[i] + t5_out_i = t5_out[i] + lg_pooled_i = lg_pooled[i] + l_attn_mask_i = l_attn_mask[i] + g_attn_mask_i = g_attn_mask[i] + t5_attn_mask_i = t5_attn_mask[i] + apply_lg_attn_mask = self.apply_lg_attn_mask + apply_t5_attn_mask = self.apply_t5_attn_mask + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + lg_out=lg_out_i, + lg_pooled=lg_pooled_i, + t5_out=t5_out_i, + clip_l_attn_mask=l_attn_mask_i, + clip_g_attn_mask=g_attn_mask_i, + t5_attn_mask=t5_attn_mask_i, + apply_lg_attn_mask=apply_lg_attn_mask, + apply_t5_attn_mask=apply_t5_attn_mask, + ) + else: + # it's fine that attn mask is not None. it's overwritten before calling the model if necessary + info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i) + + +class Sd3LatentsCachingStrategy(LatentsCachingStrategy): + SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + @property + def cache_suffix(self) -> str: + return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents( + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True + ) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py new file mode 100644 index 00000000..6b3e2afa --- /dev/null +++ b/library/strategy_sdxl.py @@ -0,0 +1,306 @@ +import os +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection +from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy + + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +TOKENIZER1_PATH = "openai/clip-vit-large-patch14" +TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + + +class SdxlTokenizeStrategy(TokenizeStrategy): + def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None: + self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir) + self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir) + self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2 + + if max_length is None: + self.max_length = self.tokenizer1.model_max_length + else: + self.max_length = max_length + 2 + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + return ( + torch.stack([self._get_input_ids(self.tokenizer1, t, self.max_length) for t in text], dim=0), + torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0), + ) + + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: + text = [text] if isinstance(text, str) else text + tokens1_list, tokens2_list = [], [] + weights1_list, weights2_list = [], [] + for t in text: + tokens1, weights1 = self._get_input_ids(self.tokenizer1, t, self.max_length, weighted=True) + tokens2, weights2 = self._get_input_ids(self.tokenizer2, t, self.max_length, weighted=True) + tokens1_list.append(tokens1) + tokens2_list.append(tokens2) + weights1_list.append(weights1) + weights2_list.append(weights2) + return [torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)], [ + torch.stack(weights1_list, dim=0), + torch.stack(weights2_list, dim=0), + ] + + +class SdxlTextEncodingStrategy(TextEncodingStrategy): + def __init__(self) -> None: + pass + + def _pool_workaround( + self, text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int + ): + r""" + workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output + instead of the hidden states for the EOS token + If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output + + Original code from CLIP's pooling function: + + \# text_embeds.shape = [batch_size, sequence_length, transformer.width] + \# take features from the eot embedding (eot_token is the highest number in each sequence) + \# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + ] + """ + + # input_ids: b*n,77 + # find index for EOS token + + # Following code is not working if one of the input_ids has multiple EOS tokens (very odd case) + # eos_token_index = torch.where(input_ids == eos_token_id)[1] + # eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # Create a mask where the EOS tokens are + eos_token_mask = (input_ids == eos_token_id).int() + + # Use argmax to find the last index of the EOS token for each element in the batch + eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine + eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # get hidden states for EOS token + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index + ] + + # apply projection: projection may be of different dtype than last_hidden_state + pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype)) + pooled_output = pooled_output.to(last_hidden_state.dtype) + + return pooled_output + + def _get_hidden_states_sdxl( + self, + input_ids1: torch.Tensor, + input_ids2: torch.Tensor, + tokenizer1: CLIPTokenizer, + tokenizer2: CLIPTokenizer, + text_encoder1: Union[CLIPTextModel, torch.nn.Module], + text_encoder2: Union[CLIPTextModelWithProjection, torch.nn.Module], + unwrapped_text_encoder2: Optional[CLIPTextModelWithProjection] = None, + ): + # input_ids: b,n,77 -> b*n, 77 + b_size = input_ids1.size()[0] + if input_ids1.size()[1] == 1: + max_token_length = None + else: + max_token_length = input_ids1.size()[1] * input_ids1.size()[2] + input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77 + input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77 + input_ids1 = input_ids1.to(text_encoder1.device) + input_ids2 = input_ids2.to(text_encoder2.device) + + # text_encoder1 + enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True) + hidden_states1 = enc_out["hidden_states"][11] + + # text_encoder2 + enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True) + hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer + + # pool2 = enc_out["text_embeds"] + unwrapped_text_encoder2 = unwrapped_text_encoder2 or text_encoder2 + pool2 = self._pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id) + + # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280 + n_size = 1 if max_token_length is None else max_token_length // 75 + hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1])) + hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1])) + + if max_token_length is not None: + # bs*3, 77, 768 or 1024 + # encoder1: ... の三連を ... へ戻す + states_list = [hidden_states1[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, tokenizer1.model_max_length): + states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # の後から の前まで + states_list.append(hidden_states1[:, -1].unsqueeze(1)) # + hidden_states1 = torch.cat(states_list, dim=1) + + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [hidden_states2[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, tokenizer2.model_max_length): + chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # の後から 最後の前まで + # this causes an error: + # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation + # if i > 1: + # for j in range(len(chunk)): # batch_size + # if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり ...のパターン + # chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(hidden_states2[:, -1].unsqueeze(1)) # のどちらか + hidden_states2 = torch.cat(states_list, dim=1) + + # pool はnの最初のものを使う + pool2 = pool2[::n_size] + + return hidden_states1, hidden_states2, pool2 + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Args: + tokenize_strategy: TokenizeStrategy + models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)]. + If text_encoder2 is wrapped by accelerate, unwrapped_text_encoder2 is required + tokens: List of tokens, for text_encoder1 and text_encoder2 + """ + if len(models) == 2: + text_encoder1, text_encoder2 = models + unwrapped_text_encoder2 = None + else: + text_encoder1, text_encoder2, unwrapped_text_encoder2 = models + tokens1, tokens2 = tokens + sdxl_tokenize_strategy = tokenize_strategy # type: SdxlTokenizeStrategy + tokenizer1, tokenizer2 = sdxl_tokenize_strategy.tokenizer1, sdxl_tokenize_strategy.tokenizer2 + + hidden_states1, hidden_states2, pool2 = self._get_hidden_states_sdxl( + tokens1, tokens2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, unwrapped_text_encoder2 + ) + return [hidden_states1, hidden_states2, pool2] + + def encode_tokens_with_weights( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens_list: List[torch.Tensor], + weights_list: List[torch.Tensor], + ) -> List[torch.Tensor]: + hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens_list) + + weights_list = [weights.to(hidden_states1.device) for weights in weights_list] + + # apply weights + if weights_list[0].shape[1] == 1: # no max_token_length + # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) + hidden_states1 = hidden_states1 * weights_list[0].squeeze(1).unsqueeze(2) + hidden_states2 = hidden_states2 * weights_list[1].squeeze(1).unsqueeze(2) + else: + # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) + for weight, hidden_states in zip(weights_list, [hidden_states1, hidden_states2]): + for i in range(weight.shape[1]): + hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[ + :, i, 1:-1 + ].unsqueeze(-1) + + return [hidden_states1, hidden_states2, pool2] + + +class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz" + + def __init__( + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + is_weighted: bool = False, + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted) + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, npz_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(npz_path) + if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + hidden_state1 = data["hidden_state1"] + hidden_state2 = data["hidden_state2"] + pool2 = data["pool2"] + return [hidden_state1, hidden_state2, pool2] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy + captions = [info.caption for info in infos] + + if self.is_weighted: + tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions) + with torch.no_grad(): + hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, models, tokens_list, weights_list + ) + else: + tokens1, tokens2 = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, [tokens1, tokens2] + ) + + if hidden_state1.dtype == torch.bfloat16: + hidden_state1 = hidden_state1.float() + if hidden_state2.dtype == torch.bfloat16: + hidden_state2 = hidden_state2.float() + if pool2.dtype == torch.bfloat16: + pool2 = pool2.float() + + hidden_state1 = hidden_state1.cpu().numpy() + hidden_state2 = hidden_state2.cpu().numpy() + pool2 = pool2.cpu().numpy() + + for i, info in enumerate(infos): + hidden_state1_i = hidden_state1[i] + hidden_state2_i = hidden_state2[i] + pool2_i = pool2[i] + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + hidden_state1=hidden_state1_i, + hidden_state2=hidden_state2_i, + pool2=pool2_i, + ) + else: + info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i] diff --git a/library/train_util.py b/library/train_util.py index a7f0ecd8..bf1b6731 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3,6 +3,7 @@ import argparse import ast import asyncio +from concurrent.futures import Future, ThreadPoolExecutor import datetime import importlib import json @@ -13,13 +14,15 @@ import shutil import time import typing from typing import ( + Any, + Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, - Union, + Union ) from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState import glob @@ -31,10 +34,14 @@ import subprocess from io import BytesIO import toml +# from concurrent.futures import ThreadPoolExecutor, as_completed + from tqdm import tqdm +from packaging.version import Version import torch from library.device_utils import init_ipex, clean_memory_on_device +from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy init_ipex() @@ -43,7 +50,11 @@ from torch.optim import Optimizer from torchvision import transforms from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection import transformers -from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION +from diffusers.optimization import ( + SchedulerType as DiffusersSchedulerType, + TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION, +) +from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from diffusers import ( StableDiffusionPipeline, DDPMScheduler, @@ -59,7 +70,7 @@ from diffusers import ( KDPM2AncestralDiscreteScheduler, AutoencoderKL, ) -from library import custom_train_functions +from library import custom_train_functions, sd3_utils from library.original_unet import UNet2DConditionModel from huggingface_hub import hf_hub_download import numpy as np @@ -68,11 +79,12 @@ import imagesize import cv2 import safetensors.torch from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline +from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec import library.deepspeed_utils as deepspeed_utils -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize @@ -84,10 +96,6 @@ logger = logging.getLogger(__name__) # from library.hypernetwork import replace_attentions_for_hypernetwork from library.original_unet import UNet2DConditionModel -# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う -TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ - HIGH_VRAM = False # checkpointファイル名 @@ -138,23 +146,9 @@ IMAGE_TRANSFORMS = transforms.Compose( ) TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" +TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" -def split_train_val(paths, is_train, validation_split, validation_seed): - if validation_seed is not None: - print(f"Using validation seed: {validation_seed}") - prevstate = random.getstate() - random.seed(validation_seed) - random.shuffle(paths) - random.setstate(prevstate) - else: - random.shuffle(paths) - - if is_train: - return paths[0:math.ceil(len(paths) * (1 - validation_split))] - else: - return paths[len(paths) - round(len(paths) * validation_split):] - -def split_train_val(paths, is_train, validation_split, validation_seed): +def split_train_val(paths: List[str], is_train: bool, validation_split: float, validation_seed: int) -> List[str]: if validation_seed is not None: print(f"Using validation seed: {validation_seed}") prevstate = random.getstate() @@ -180,19 +174,26 @@ class ImageInfo: self.image_size: Tuple[int, int] = None self.resized_size: Tuple[int, int] = None self.bucket_reso: Tuple[int, int] = None - self.latents: torch.Tensor = None - self.latents_flipped: torch.Tensor = None - self.latents_npz: str = None - self.latents_original_size: Tuple[int, int] = None # original image size, not latents size - self.latents_crop_ltrb: Tuple[int, int] = None # crop left top right bottom in original pixel size, not latents size - self.cond_img_path: str = None + self.latents: Optional[torch.Tensor] = None + self.latents_flipped: Optional[torch.Tensor] = None + self.latents_npz: Optional[str] = None # set in cache_latents + self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size + self.latents_crop_ltrb: Optional[Tuple[int, int]] = ( + None # crop left top right bottom in original pixel size, not latents size + ) + self.cond_img_path: Optional[str] = None self.image: Optional[Image.Image] = None # optional, original PIL Image - # SDXL, optional - self.text_encoder_outputs_npz: Optional[str] = None + self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs + + # new + self.text_encoder_outputs: Optional[List[torch.Tensor]] = None + # old self.text_encoder_outputs1: Optional[torch.Tensor] = None self.text_encoder_outputs2: Optional[torch.Tensor] = None self.text_encoder_pool2: Optional[torch.Tensor] = None + self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime + class BucketManager: def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: @@ -394,6 +395,7 @@ class BaseSubset: def __init__( self, image_dir: Optional[str], + alpha_mask: Optional[bool], num_repeats: int, shuffle_caption: bool, caption_separator: str, @@ -412,8 +414,12 @@ class BaseSubset: caption_suffix: Optional[str], token_warmup_min: int, token_warmup_step: Union[float, int], + custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: self.image_dir = image_dir + self.alpha_mask = alpha_mask if alpha_mask is not None else False self.num_repeats = num_repeats self.shuffle_caption = shuffle_caption self.caption_separator = caption_separator @@ -434,8 +440,13 @@ class BaseSubset: self.token_warmup_min = token_warmup_min # step=0におけるタグの数 self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる + self.custom_attributes = custom_attributes if custom_attributes is not None else {} + self.img_count = 0 + self.validation_seed = validation_seed + self.validation_split = validation_split + class DreamBoothSubset(BaseSubset): def __init__( @@ -445,6 +456,7 @@ class DreamBoothSubset(BaseSubset): class_tokens: Optional[str], caption_extension: str, cache_info: bool, + alpha_mask: bool, num_repeats, shuffle_caption, caption_separator: str, @@ -463,11 +475,15 @@ class DreamBoothSubset(BaseSubset): caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" super().__init__( image_dir, + alpha_mask, num_repeats, shuffle_caption, caption_separator, @@ -486,6 +502,9 @@ class DreamBoothSubset(BaseSubset): caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.is_reg = is_reg @@ -506,6 +525,7 @@ class FineTuningSubset(BaseSubset): self, image_dir, metadata_file: str, + alpha_mask: bool, num_repeats, shuffle_caption, caption_separator, @@ -524,11 +544,15 @@ class FineTuningSubset(BaseSubset): caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" super().__init__( image_dir, + alpha_mask, num_repeats, shuffle_caption, caption_separator, @@ -547,6 +571,9 @@ class FineTuningSubset(BaseSubset): caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.metadata_file = metadata_file @@ -582,11 +609,15 @@ class ControlNetSubset(BaseSubset): caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" super().__init__( image_dir, + False, # alpha_mask num_repeats, shuffle_caption, caption_separator, @@ -605,6 +636,9 @@ class ControlNetSubset(BaseSubset): caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.conditioning_data_dir = conditioning_data_dir @@ -622,17 +656,12 @@ class ControlNetSubset(BaseSubset): class BaseDataset(torch.utils.data.Dataset): def __init__( self, - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]], - max_token_length: int, resolution: Optional[Tuple[int, int]], network_multiplier: float, debug_dataset: bool, ) -> None: super().__init__() - self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer] - - self.max_token_length = max_token_length # width/height is used when enable_bucket==False self.width, self.height = (None, None) if resolution is None else resolution self.network_multiplier = network_multiplier @@ -653,8 +682,6 @@ class BaseDataset(torch.utils.data.Dataset): self.bucket_no_upscale = None self.bucket_info = None # for metadata - self.tokenizer_max_length = self.tokenizers[0].model_max_length if max_token_length is None else max_token_length + 2 - self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ self.current_step: int = 0 @@ -674,6 +701,43 @@ class BaseDataset(torch.utils.data.Dataset): # caching self.caching_mode = None # None, 'latents', 'text' + self.tokenize_strategy = None + self.text_encoder_output_caching_strategy = None + self.latents_caching_strategy = None + + def set_current_strategies(self): + self.tokenize_strategy = TokenizeStrategy.get_strategy() + self.text_encoder_output_caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy() + self.latents_caching_strategy = LatentsCachingStrategy.get_strategy() + + def adjust_min_max_bucket_reso_by_steps( + self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int + ) -> Tuple[int, int]: + # make min/max bucket reso to be multiple of bucket_reso_steps + if min_bucket_reso % bucket_reso_steps != 0: + adjusted_min_bucket_reso = min_bucket_reso - min_bucket_reso % bucket_reso_steps + logger.warning( + f"min_bucket_reso is adjusted to be multiple of bucket_reso_steps" + f" / min_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {min_bucket_reso} -> {adjusted_min_bucket_reso}" + ) + min_bucket_reso = adjusted_min_bucket_reso + if max_bucket_reso % bucket_reso_steps != 0: + adjusted_max_bucket_reso = max_bucket_reso + bucket_reso_steps - max_bucket_reso % bucket_reso_steps + logger.warning( + f"max_bucket_reso is adjusted to be multiple of bucket_reso_steps" + f" / max_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {max_bucket_reso} -> {adjusted_max_bucket_reso}" + ) + max_bucket_reso = adjusted_max_bucket_reso + + assert ( + min(resolution) >= min_bucket_reso + ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" + assert ( + max(resolution) <= max_bucket_reso + ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + + return min_bucket_reso, max_bucket_reso + def set_seed(self, seed): self.seed = seed @@ -682,8 +746,16 @@ class BaseDataset(torch.utils.data.Dataset): def set_current_epoch(self, epoch): if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする - self.shuffle_buckets() - self.current_epoch = epoch + if epoch > self.current_epoch: + logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) + num_epochs = epoch - self.current_epoch + for _ in range(num_epochs): + self.current_epoch += 1 + self.shuffle_buckets() + # self.current_epoch seem to be set to 0 again in the next epoch. it may be caused by skipped_dataloader? + else: + logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) + self.current_epoch = epoch def set_current_step(self, step): self.current_step = step @@ -888,6 +960,23 @@ class BaseDataset(torch.utils.data.Dataset): if info.image_size is None: info.image_size = self.get_image_size(info.absolute_path) + # # run in parallel + # max_workers = min(os.cpu_count(), len(self.image_data)) # TODO consider multi-gpu (processes) + # with ThreadPoolExecutor(max_workers) as executor: + # futures = [] + # for info in tqdm(self.image_data.values(), desc="loading image sizes"): + # if info.image_size is None: + # def get_and_set_image_size(info): + # info.image_size = self.get_image_size(info.absolute_path) + # futures.append(executor.submit(get_and_set_image_size, info)) + # # consume futures to reduce memory usage and prevent Ctrl-C hang + # if len(futures) >= max_workers: + # for future in futures: + # future.result() + # futures = [] + # for future in futures: + # future.result() + if self.enable_bucket: logger.info("make buckets") else: @@ -948,28 +1037,12 @@ class BaseDataset(torch.utils.data.Dataset): logger.info(f"mean ar error (without repeats): {mean_img_ar_error}") # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる - self.buckets_indices: List(BucketBatchIndex) = [] + self.buckets_indices: List[BucketBatchIndex] = [] for bucket_index, bucket in enumerate(self.bucket_manager.buckets): batch_count = int(math.ceil(len(bucket) / self.batch_size)) for batch_index in range(batch_count): self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index)) - # ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す - #  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる - # - # # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは - # # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう - # # そのためバッチサイズを画像種類までに制限する - # # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない? - # # TO DO 正則化画像をepochまたがりで利用する仕組み - # num_of_image_types = len(set(bucket)) - # bucket_batch_size = min(self.batch_size, num_of_image_types) - # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) - # # logger.info(bucket_index, num_of_image_types, bucket_batch_size, batch_count) - # for batch_index in range(batch_count): - # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) - # ↑ここまで - self.shuffle_buckets() self._length = len(self.buckets_indices) @@ -1002,7 +1075,111 @@ class BaseDataset(torch.utils.data.Dataset): ] ) - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + def new_cache_latents(self, model: Any, accelerator: Accelerator): + r""" + a brand new method to cache latents. This method caches latents with caching strategy. + normal cache_latents method is used by default, but this method is used when caching strategy is specified. + """ + logger.info("caching latents with caching strategy.") + caching_strategy = LatentsCachingStrategy.get_strategy() + image_infos = list(self.image_data.values()) + + # sort by resolution + image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) + + # split by resolution and some conditions + class Condition: + def __init__(self, reso, flip_aug, alpha_mask, random_crop): + self.reso = reso + self.flip_aug = flip_aug + self.alpha_mask = alpha_mask + self.random_crop = random_crop + + def __eq__(self, other): + return ( + self.reso == other.reso + and self.flip_aug == other.flip_aug + and self.alpha_mask == other.alpha_mask + and self.random_crop == other.random_crop + ) + + batch: List[ImageInfo] = [] + current_condition = None + + # support multiple-gpus + num_processes = accelerator.num_processes + process_index = accelerator.process_index + + # define a function to submit a batch to cache + def submit_batch(batch, cond): + for info in batch: + if info.image is not None and isinstance(info.image, Future): + info.image = info.image.result() # future to image + caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop) + + # remove image from memory + for info in batch: + info.image = None + + # define ThreadPoolExecutor to load images in parallel + max_workers = min(os.cpu_count(), len(image_infos)) + max_workers = max(1, max_workers // num_processes) # consider multi-gpu + max_workers = min(max_workers, caching_strategy.batch_size) # max_workers should be less than batch_size + executor = ThreadPoolExecutor(max_workers) + + try: + # iterate images + logger.info("caching latents...") + for i, info in enumerate(tqdm(image_infos)): + subset = self.image_to_subset[info.image_key] + + if info.latents_npz is not None: # fine tuning dataset + continue + + # check disk cache exists and size of latents + if caching_strategy.cache_to_disk: + # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix + info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) + + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different latents + if i % num_processes != process_index: + continue + + # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") + + cache_available = caching_strategy.is_disk_cached_latents_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) + if cache_available: # do not add to batch + continue + + # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + submit_batch(batch, current_condition) + batch = [] + + if info.image is None: + # load image in parallel + info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask) + + batch.append(info) + current_condition = condition + + # if number of data in batch is enough, flush the batch + if len(batch) >= caching_strategy.batch_size: + submit_batch(batch, current_condition) + batch = [] + current_condition = None + + if len(batch) > 0: + submit_batch(batch, current_condition) + + finally: + executor.shutdown() + + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching latents.") @@ -1011,9 +1188,26 @@ class BaseDataset(torch.utils.data.Dataset): # sort by resolution image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) - # split by resolution - batches = [] - batch = [] + # split by resolution and some conditions + class Condition: + def __init__(self, reso, flip_aug, alpha_mask, random_crop): + self.reso = reso + self.flip_aug = flip_aug + self.alpha_mask = alpha_mask + self.random_crop = random_crop + + def __eq__(self, other): + return ( + self.reso == other.reso + and self.flip_aug == other.flip_aug + and self.alpha_mask == other.alpha_mask + and self.random_crop == other.random_crop + ) + + batches: List[Tuple[Condition, List[ImageInfo]]] = [] + batch: List[ImageInfo] = [] + current_condition = None + logger.info("checking cache validity...") for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] @@ -1023,49 +1217,147 @@ class BaseDataset(torch.utils.data.Dataset): # check disk cache exists and size of latents if cache_to_disk: - info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz" + info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix if not is_main_process: # store to info only continue - cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug) + cache_available = is_disk_cached_latents_is_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) if cache_available: # do not add to batch continue - # if last member of batch has different resolution, flush the batch - if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: - batches.append(batch) + # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + batches.append((current_condition, batch)) batch = [] batch.append(info) + current_condition = condition # if number of data in batch is enough, flush the batch if len(batch) >= vae_batch_size: - batches.append(batch) + batches.append((current_condition, batch)) batch = [] + current_condition = None if len(batch) > 0: - batches.append(batch) + batches.append((current_condition, batch)) if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only return # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded logger.info("caching latents...") - for batch in tqdm(batches, smoothing=1, total=len(batches)): - cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop) + for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): + cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) - # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる - # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する - # SD1/2に対応するにはv2のフラグを持つ必要があるので後回し + def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator): + r""" + a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy. + """ + tokenize_strategy = TokenizeStrategy.get_strategy() + text_encoding_strategy = TextEncodingStrategy.get_strategy() + caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy() + batch_size = caching_strategy.batch_size or self.batch_size + + logger.info("caching Text Encoder outputs with caching strategy.") + image_infos = list(self.image_data.values()) + + # split by resolution + batches = [] + batch = [] + + # support multiple-gpus + num_processes = accelerator.num_processes + process_index = accelerator.process_index + + logger.info("checking cache validity...") + for i, info in enumerate(tqdm(image_infos)): + # check disk cache exists and size of text encoder outputs + if caching_strategy.cache_to_disk: + te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) + info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability + + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different text encoder outputs + if i % num_processes != process_index: + continue + + cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) + if cache_available: # do not add to batch + continue + + batch.append(info) + + # if number of data in batch is enough, flush the batch + if len(batch) >= batch_size: + batches.append(batch) + batch = [] + + if len(batch) > 0: + batches.append(batch) + + if len(batches) == 0: + logger.info("no Text Encoder outputs to cache") + return + + # iterate batches + logger.info("caching Text Encoder outputs...") + for batch in tqdm(batches, smoothing=1, total=len(batches)): + # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + caching_strategy.cache_batch_outputs(tokenize_strategy, models, text_encoding_strategy, batch) + + # if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype + # this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset + # to support SD1/2, it needs a flag for v2, but it is postponed def cache_text_encoder_outputs( - self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True + self, tokenizers, text_encoders, device, output_dtype, cache_to_disk=False, is_main_process=True ): assert len(tokenizers) == 2, "only support SDXL" + return self.cache_text_encoder_outputs_common( + tokenizers, text_encoders, [device, device], output_dtype, [output_dtype], cache_to_disk, is_main_process + ) + # same as above, but for SD3 + def cache_text_encoder_outputs_sd3( + self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None + ): + return self.cache_text_encoder_outputs_common( + [tokenizer], + text_encoders, + devices, + output_dtype, + te_dtypes, + cache_to_disk, + is_main_process, + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3, + batch_size, + ) + + def cache_text_encoder_outputs_common( + self, + tokenizers, + text_encoders, + devices, + output_dtype, + te_dtypes, + cache_to_disk=False, + is_main_process=True, + file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX, + batch_size=None, + ): # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと logger.info("caching text encoder outputs.") + + tokenize_strategy = TokenizeStrategy.get_strategy() + + if batch_size is None: + batch_size = self.batch_size + image_infos = list(self.image_data.values()) logger.info("checking cache existence...") @@ -1073,13 +1365,14 @@ class BaseDataset(torch.utils.data.Dataset): for info in tqdm(image_infos): # subset = self.image_to_subset[info.image_key] if cache_to_disk: - te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX + te_out_npz = os.path.splitext(info.absolute_path)[0] + file_suffix info.text_encoder_outputs_npz = te_out_npz if not is_main_process: # store to info only continue if os.path.exists(te_out_npz): + # TODO check varidity of cache here continue image_infos_to_cache.append(info) @@ -1088,20 +1381,25 @@ class BaseDataset(torch.utils.data.Dataset): return # prepare tokenizers and text encoders - for text_encoder in text_encoders: + for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes): text_encoder.to(device) - if weight_dtype is not None: - text_encoder.to(dtype=weight_dtype) + if te_dtype is not None: + text_encoder.to(dtype=te_dtype) # create batch + is_sd3 = len(tokenizers) == 1 batch = [] batches = [] for info in image_infos_to_cache: - input_ids1 = self.get_input_ids(info.caption, tokenizers[0]) - input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) - batch.append((info, input_ids1, input_ids2)) + if not is_sd3: + input_ids1 = self.get_input_ids(info.caption, tokenizers[0]) + input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) + batch.append((info, input_ids1, input_ids2)) + else: + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(info.caption) + batch.append((info, l_tokens, g_tokens, t5_tokens)) - if len(batch) >= self.batch_size: + if len(batch) >= batch_size: batches.append(batch) batch = [] @@ -1110,19 +1408,48 @@ class BaseDataset(torch.utils.data.Dataset): # iterate batches: call text encoder and cache outputs for memory or disk logger.info("caching text encoder outputs...") - for batch in tqdm(batches): - infos, input_ids1, input_ids2 = zip(*batch) - input_ids1 = torch.stack(input_ids1, dim=0) - input_ids2 = torch.stack(input_ids2, dim=0) - cache_batch_text_encoder_outputs( - infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype - ) + if not is_sd3: + for batch in tqdm(batches): + infos, input_ids1, input_ids2 = zip(*batch) + input_ids1 = torch.stack(input_ids1, dim=0) + input_ids2 = torch.stack(input_ids2, dim=0) + cache_batch_text_encoder_outputs( + infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, output_dtype + ) + else: + for batch in tqdm(batches): + infos, l_tokens, g_tokens, t5_tokens = zip(*batch) + + # stack tokens + # l_tokens = [tokens[0] for tokens in l_tokens] + # g_tokens = [tokens[0] for tokens in g_tokens] + # t5_tokens = [tokens[0] for tokens in t5_tokens] + + cache_batch_text_encoder_outputs_sd3( + infos, + tokenizers[0], + text_encoders, + self.max_token_length, + cache_to_disk, + (l_tokens, g_tokens, t5_tokens), + output_dtype, + ) def get_image_size(self, image_path): - return imagesize.get(image_path) + # return imagesize.get(image_path) + image_size = imagesize.get(image_path) + if image_size[0] <= 0: + # imagesize doesn't work for some images, so use PIL as a fallback + try: + with Image.open(image_path) as img: + image_size = img.size + except Exception as e: + logger.warning(f"failed to get image size: {image_path}, error: {e}") + image_size = (0, 0) + return image_size - def load_image_with_face_info(self, subset: BaseSubset, image_path: str): - img = load_image(image_path) + def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False): + img = load_image(image_path, alpha_mask) face_cx = face_cy = face_w = face_h = 0 if subset.face_crop_aug_range is not None: @@ -1197,23 +1524,24 @@ class BaseDataset(torch.utils.data.Dataset): loss_weights = [] captions = [] input_ids_list = [] - input_ids2_list = [] latents_list = [] + alpha_mask_list = [] images = [] original_sizes_hw = [] crop_top_lefts = [] target_sizes_hw = [] flippeds = [] # 変数名が微妙 - text_encoder_outputs1_list = [] - text_encoder_outputs2_list = [] - text_encoder_pool2_list = [] + text_encoder_outputs_list = [] + custom_attributes = [] for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] - loss_weights.append( - self.prior_loss_weight if image_info.is_reg else 1.0 - ) # in case of fine tuning, is_reg is always False + + custom_attributes.append(subset.custom_attributes) + + # in case of fine tuning, is_reg is always False + loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance @@ -1223,21 +1551,30 @@ class BaseDataset(torch.utils.data.Dataset): crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped if not flipped: latents = image_info.latents + alpha_mask = image_info.alpha_mask else: latents = image_info.latents_flipped + alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1]) image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz) + latents, original_size, crop_ltrb, flipped_latents, alpha_mask = ( + self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso) + ) if flipped: latents = flipped_latents + alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem del flipped_latents latents = torch.FloatTensor(latents) + if alpha_mask is not None: + alpha_mask = torch.FloatTensor(alpha_mask) image = None else: # 画像を読み込み、必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path) + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info( + subset, image_info.absolute_path, subset.alpha_mask + ) im_h, im_w = img.shape[0:2] if self.enable_bucket: @@ -1269,16 +1606,33 @@ class BaseDataset(torch.utils.data.Dataset): # augmentation aug = self.aug_helper.get_augmentor(subset.color_aug) if aug is not None: - img = aug(image=img)["image"] + # augment RGB channels only + img_rgb = img[:, :, :3] + img_rgb = aug(image=img_rgb)["image"] + img[:, :, :3] = img_rgb if flipped: img = img[:, ::-1, :].copy() # copy to avoid negative stride problem + if subset.alpha_mask: + if img.shape[2] == 4: + alpha_mask = img[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 # 0.0~1.0 + alpha_mask = torch.FloatTensor(alpha_mask) + else: + alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32) + else: + alpha_mask = None + + img = img[:, :, :3] # remove alpha channel + latents = None image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + del img images.append(image) latents_list.append(latents) + alpha_mask_list.append(alpha_mask) target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8) @@ -1295,74 +1649,86 @@ class BaseDataset(torch.utils.data.Dataset): # captionとtext encoder outputを処理する caption = image_info.caption # default - if image_info.text_encoder_outputs1 is not None: - text_encoder_outputs1_list.append(image_info.text_encoder_outputs1) - text_encoder_outputs2_list.append(image_info.text_encoder_outputs2) - text_encoder_pool2_list.append(image_info.text_encoder_pool2) - captions.append(caption) + + tokenization_required = ( + self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial + ) + text_encoder_outputs = None + input_ids = None + + if image_info.text_encoder_outputs is not None: + # cached + text_encoder_outputs = image_info.text_encoder_outputs elif image_info.text_encoder_outputs_npz is not None: - text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk( + # on disk + text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( image_info.text_encoder_outputs_npz ) - text_encoder_outputs1_list.append(text_encoder_outputs1) - text_encoder_outputs2_list.append(text_encoder_outputs2) - text_encoder_pool2_list.append(text_encoder_pool2) - captions.append(caption) else: + tokenization_required = True + text_encoder_outputs_list.append(text_encoder_outputs) + + if tokenization_required: caption = self.process_caption(subset, image_info.caption) - if self.XTI_layers: - caption_layer = [] - for layer in self.XTI_layers: - token_strings_from = " ".join(self.token_strings) - token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) - caption_ = caption.replace(token_strings_from, token_strings_to) - caption_layer.append(caption_) - captions.append(caption_layer) - else: - captions.append(caption) + input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension + # if self.XTI_layers: + # caption_layer = [] + # for layer in self.XTI_layers: + # token_strings_from = " ".join(self.token_strings) + # token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) + # caption_ = caption.replace(token_strings_from, token_strings_to) + # caption_layer.append(caption_) + # captions.append(caption_layer) + # else: + # captions.append(caption) - if not self.token_padding_disabled: # this option might be omitted in future - if self.XTI_layers: - token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) - else: - token_caption = self.get_input_ids(caption, self.tokenizers[0]) - input_ids_list.append(token_caption) + # if not self.token_padding_disabled: # this option might be omitted in future + # # TODO get_input_ids must support SD3 + # if self.XTI_layers: + # token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) + # else: + # token_caption = self.get_input_ids(caption, self.tokenizers[0]) + # input_ids_list.append(token_caption) - if len(self.tokenizers) > 1: - if self.XTI_layers: - token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) - else: - token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) - input_ids2_list.append(token_caption2) + # if len(self.tokenizers) > 1: + # if self.XTI_layers: + # token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) + # else: + # token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) + # input_ids2_list.append(token_caption2) + input_ids_list.append(input_ids) + captions.append(caption) + + def none_or_stack_elements(tensors_list, converter): + # [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)] + if len(tensors_list) == 0 or tensors_list[0] == None or len(tensors_list[0]) == 0 or tensors_list[0][0] is None: + return None + return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))] + + # set example example = {} + example["custom_attributes"] = custom_attributes # may be list of empty dict example["loss_weights"] = torch.FloatTensor(loss_weights) + example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor) + example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x) - if len(text_encoder_outputs1_list) == 0: - if self.token_padding_disabled: - # padding=True means pad in the batch - example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids - if len(self.tokenizers) > 1: - example["input_ids2"] = self.tokenizer[1]( - captions, padding=True, truncation=True, return_tensors="pt" - ).input_ids - else: - example["input_ids2"] = None - else: - example["input_ids"] = torch.stack(input_ids_list) - example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None - example["text_encoder_outputs1_list"] = None - example["text_encoder_outputs2_list"] = None - example["text_encoder_pool2_list"] = None + # if one of alpha_masks is not None, we need to replace None with ones + none_or_not = [x is None for x in alpha_mask_list] + if all(none_or_not): + example["alpha_masks"] = None + elif any(none_or_not): + for i in range(len(alpha_mask_list)): + if alpha_mask_list[i] is None: + if images[i] is not None: + alpha_mask_list[i] = torch.ones((images[i].shape[1], images[i].shape[2]), dtype=torch.float32) + else: + alpha_mask_list[i] = torch.ones( + (latents_list[i].shape[1] * 8, latents_list[i].shape[2] * 8), dtype=torch.float32 + ) + example["alpha_masks"] = torch.stack(alpha_mask_list) else: - example["input_ids"] = None - example["input_ids2"] = None - # # for assertion - # example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions]) - # example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions]) - example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list) - example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list) - example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list) + example["alpha_masks"] = torch.stack(alpha_mask_list) if images[0] is not None: images = torch.stack(images) @@ -1394,6 +1760,7 @@ class BaseDataset(torch.utils.data.Dataset): resized_sizes = [] bucket_reso = None flip_aug = None + alpha_mask = None random_crop = None for image_key in bucket[image_index : image_index + bucket_batch_size]: @@ -1402,10 +1769,13 @@ class BaseDataset(torch.utils.data.Dataset): if flip_aug is None: flip_aug = subset.flip_aug + alpha_mask = subset.alpha_mask random_crop = subset.random_crop bucket_reso = image_info.bucket_reso else: + # TODO そもそも混在してても動くようにしたほうがいい assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch" + assert alpha_mask == subset.alpha_mask, "alpha_mask must be same in a batch" assert random_crop == subset.random_crop, "random_crop must be same in a batch" assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch" @@ -1442,6 +1812,7 @@ class BaseDataset(torch.utils.data.Dataset): example["absolute_paths"] = absolute_paths example["resized_sizes"] = resized_sizes example["flip_aug"] = flip_aug + example["alpha_mask"] = alpha_mask example["random_crop"] = random_crop example["bucket_reso"] = bucket_reso return example @@ -1453,10 +1824,7 @@ class DreamBoothDataset(BaseDataset): def __init__( self, subsets: Sequence[DreamBoothSubset], - is_train: bool, batch_size: int, - tokenizer, - max_token_length, resolution, network_multiplier: float, enable_bucket: bool, @@ -1469,26 +1837,24 @@ class DreamBoothDataset(BaseDataset): validation_seed: Optional[int], debug_dataset, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" - self.is_train = is_train self.validation_split = validation_split self.validation_seed = validation_seed self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight self.latents_cache = None + self.validation_seed = validation_seed + self.validation_split = validation_split self.enable_bucket = enable_bucket if self.enable_bucket: - assert ( - min(resolution) >= min_bucket_reso - ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" - assert ( - max(resolution) <= max_bucket_reso - ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps( + resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps + ) self.min_bucket_reso = min_bucket_reso self.max_bucket_reso = max_bucket_reso self.bucket_reso_steps = bucket_reso_steps @@ -1528,7 +1894,7 @@ class DreamBoothDataset(BaseDataset): def load_dreambooth_dir(subset: DreamBoothSubset): if not os.path.isdir(subset.image_dir): logger.warning(f"not directory: {subset.image_dir}") - return [], [] + return [], [], [] info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE) use_cached_info_for_subset = subset.cache_info @@ -1553,10 +1919,43 @@ class DreamBoothDataset(BaseDataset): # we may need to check image size and existence of image files, but it takes time, so user should check it before training else: img_paths = glob_images(subset.image_dir, "*") - if self.validation_split > 0.0: - img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) sizes = [None] * len(img_paths) + # new caching: get image size from cache files + strategy = LatentsCachingStrategy.get_strategy() + if strategy is not None: + logger.info("get image size from name of cache files") + + # make image path to npz path mapping + npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix)) + npz_paths.sort( + key=lambda item: item.rsplit("_", maxsplit=2)[0] + ) # sort by name excluding resolution and cache_suffix + npz_path_index = 0 + + size_set_count = 0 + for i, img_path in enumerate(tqdm(img_paths)): + l = len(os.path.splitext(img_path)[0]) # remove extension + found = False + while npz_path_index < len(npz_paths): # until found or end of npz_paths + # npz_paths are sorted, so if npz_path > img_path, img_path is not found + if npz_paths[npz_path_index][:l] > img_path[:l]: + break + if npz_paths[npz_path_index][:l] == img_path[:l]: # found + found = True + break + npz_path_index += 1 # next npz_path + + if found: + w, h = strategy.get_image_size_from_disk_cache_path(img_path, npz_paths[npz_path_index]) + else: + w, h = None, None + + if w is not None and h is not None: + sizes[i] = [w, h] + size_set_count += 1 + logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") if use_cached_info_for_subset: @@ -1566,7 +1965,7 @@ class DreamBoothDataset(BaseDataset): # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] missing_captions = [] - for img_path in img_paths: + for img_path in tqdm(img_paths, desc="read caption"): cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard) if cap_for_img is None and subset.class_tokens is None: logger.warning( @@ -1651,7 +2050,8 @@ class DreamBoothDataset(BaseDataset): subset.img_count = len(img_paths) self.subsets.append(subset) - logger.info(f"{num_train_images} train images with repeating.") + logger.info(f"{num_train_images} images with repeating.") + self.num_train_images = num_train_images logger.info(f"{num_reg_images} reg images.") @@ -1684,8 +2084,6 @@ class FineTuningDataset(BaseDataset): self, subsets: Sequence[FineTuningSubset], batch_size: int, - tokenizer, - max_token_length, resolution, network_multiplier: float, enable_bucket: bool, @@ -1694,8 +2092,11 @@ class FineTuningDataset(BaseDataset): bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset: bool, + is_train: bool, + validation_seed: int, + validation_split: float, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset) self.batch_size = batch_size @@ -1853,6 +2254,9 @@ class FineTuningDataset(BaseDataset): self.enable_bucket = enable_bucket if self.enable_bucket: + min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps( + resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps + ) self.min_bucket_reso = min_bucket_reso self.max_bucket_reso = max_bucket_reso self.bucket_reso_steps = bucket_reso_steps @@ -1910,8 +2314,6 @@ class ControlNetDataset(BaseDataset): subsets: Sequence[ControlNetSubset], is_train: bool, batch_size: int, - tokenizer, - max_token_length, resolution, network_multiplier: float, enable_bucket: bool, @@ -1923,7 +2325,7 @@ class ControlNetDataset(BaseDataset): validation_seed: Optional[int], debug_dataset: float, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + super().__init__(resolution, network_multiplier, debug_dataset) db_subsets = [] for subset in subsets: @@ -1936,6 +2338,7 @@ class ControlNetDataset(BaseDataset): None, subset.caption_extension, subset.cache_info, + False, subset.num_repeats, subset.shuffle_caption, subset.caption_separator, @@ -1961,8 +2364,6 @@ class ControlNetDataset(BaseDataset): db_subsets, is_train, batch_size, - tokenizer, - max_token_length, resolution, network_multiplier, enable_bucket, @@ -1973,7 +2374,7 @@ class ControlNetDataset(BaseDataset): 1.0, validation_split, validation_seed, - debug_dataset, + debug_dataset ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) @@ -2023,6 +2424,9 @@ class ControlNetDataset(BaseDataset): self.conditioning_image_transforms = IMAGE_TRANSFORMS + def set_current_strategies(self): + return self.dreambooth_dataset_delegate.set_current_strategies() + def make_buckets(self): self.dreambooth_dataset_delegate.make_buckets() self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager @@ -2031,6 +2435,12 @@ class ControlNetDataset(BaseDataset): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + def new_cache_latents(self, model: Any, accelerator: Accelerator): + return self.dreambooth_dataset_delegate.new_cache_latents(model, accelerator) + + def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process) + def __len__(self): return self.dreambooth_dataset_delegate.__len__() @@ -2074,9 +2484,7 @@ class ControlNetDataset(BaseDataset): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img = cv2.resize( - cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4 - ) + cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0]))) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -2116,14 +2524,27 @@ class DatasetGroup(torch.utils.data.ConcatDataset): # for dataset in self.datasets: # dataset.make_buckets() + def set_text_encoder_output_caching_strategy(self, strategy: TextEncoderOutputsCachingStrategy): + """ + DataLoader is run in multiple processes, so we need to set the strategy manually. + """ + for dataset in self.datasets: + dataset.set_text_encoder_output_caching_strategy(strategy) + def enable_XTI(self, *args, **kwargs): for dataset in self.datasets: dataset.enable_XTI(*args, **kwargs) - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): for i, dataset in enumerate(self.datasets): logger.info(f"[Dataset {i}]") - dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) + + def new_cache_latents(self, model: Any, accelerator: Accelerator): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.new_cache_latents(model, accelerator) + accelerator.wait_for_everyone() def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True @@ -2132,6 +2553,21 @@ class DatasetGroup(torch.utils.data.ConcatDataset): logger.info(f"[Dataset {i}]") dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) + def cache_text_encoder_outputs_sd3( + self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None + ): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.cache_text_encoder_outputs_sd3( + tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size + ) + + def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.new_cache_text_encoder_outputs(models, accelerator) + accelerator.wait_for_everyone() + def set_caching_mode(self, caching_mode): for dataset in self.datasets: dataset.set_caching_mode(caching_mode) @@ -2140,12 +2576,19 @@ class DatasetGroup(torch.utils.data.ConcatDataset): for dataset in self.datasets: dataset.verify_bucket_reso_steps(min_steps) + def get_resolutions(self) -> List[Tuple[int, int]]: + return [(dataset.width, dataset.height) for dataset in self.datasets] + def is_latent_cacheable(self) -> bool: return all([dataset.is_latent_cacheable() for dataset in self.datasets]) def is_text_encoder_output_cacheable(self) -> bool: return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets]) + def set_current_strategies(self): + for dataset in self.datasets: + dataset.set_current_strategies() + def set_current_epoch(self, epoch): for dataset in self.datasets: dataset.set_current_epoch(epoch) @@ -2163,53 +2606,70 @@ class DatasetGroup(torch.utils.data.ConcatDataset): dataset.disable_token_padding() -def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): +def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool): expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意 if not os.path.exists(npz_path): return False - npz = np.load(npz_path) - if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? - return False - if npz["latents"].shape[1:3] != expected_latents_size: - return False + try: + npz = np.load(npz_path) + if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? + return False + if npz["latents"].shape[1:3] != expected_latents_size: + return False - if flip_aug: - if "latents_flipped" not in npz: - return False - if npz["latents_flipped"].shape[1:3] != expected_latents_size: - return False + if flip_aug: + if "latents_flipped" not in npz: + return False + if npz["latents_flipped"].shape[1:3] != expected_latents_size: + return False + + if alpha_mask: + if "alpha_mask" not in npz: + return False + if (npz["alpha_mask"].shape[1], npz["alpha_mask"].shape[0]) != reso: # HxW => WxH != reso + return False + else: + if "alpha_mask" in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e return True # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) -def load_latents_from_disk( - npz_path, -) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]: - npz = np.load(npz_path) - if "latents" not in npz: - raise ValueError(f"error: npz is old format. please re-generate {npz_path}") +# TODO update to use CachingStrategy +# def load_latents_from_disk( +# npz_path, +# ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: +# npz = np.load(npz_path) +# if "latents" not in npz: +# raise ValueError(f"error: npz is old format. please re-generate {npz_path}") - latents = npz["latents"] - original_size = npz["original_size"].tolist() - crop_ltrb = npz["crop_ltrb"].tolist() - flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None - return latents, original_size, crop_ltrb, flipped_latents +# latents = npz["latents"] +# original_size = npz["original_size"].tolist() +# crop_ltrb = npz["crop_ltrb"].tolist() +# flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None +# alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None +# return latents, original_size, crop_ltrb, flipped_latents, alpha_mask -def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None): - kwargs = {} - if flipped_latents_tensor is not None: - kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() - np.savez( - npz_path, - latents=latents_tensor.float().cpu().numpy(), - original_size=np.array(original_size), - crop_ltrb=np.array(crop_ltrb), - **kwargs, - ) +# def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None): +# kwargs = {} +# if flipped_latents_tensor is not None: +# kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() +# if alpha_mask is not None: +# kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() +# np.savez( +# npz_path, +# latents=latents_tensor.float().cpu().numpy(), +# original_size=np.array(original_size), +# crop_ltrb=np.array(crop_ltrb), +# **kwargs, +# ) def debug_dataset(train_dataset, show_input_ids=False): @@ -2236,12 +2696,12 @@ def debug_dataset(train_dataset, show_input_ids=False): example = train_dataset[idx] if example["latents"] is not None: logger.info(f"sample has latents from npz file: {example['latents'].size()}") - for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate( + for j, (ik, cap, lw, orgsz, crptl, trgsz, flpdz) in enumerate( zip( example["image_keys"], example["captions"], example["loss_weights"], - example["input_ids"], + # example["input_ids"], example["original_sizes_hw"], example["crop_top_lefts"], example["target_sizes_hw"], @@ -2252,12 +2712,14 @@ def debug_dataset(train_dataset, show_input_ids=False): f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}' ) if "network_multipliers" in example: - print(f"network multiplier: {example['network_multipliers'][j]}") + logger.info(f"network multiplier: {example['network_multipliers'][j]}") + if "custom_attributes" in example: + logger.info(f"custom attributes: {example['custom_attributes'][j]}") - if show_input_ids: - logger.info(f"input ids: {iid}") - if "input_ids2" in example: - logger.info(f"input ids2: {example['input_ids2'][j]}") + # if show_input_ids: + # logger.info(f"input ids: {iid}") + # if "input_ids2" in example: + # logger.info(f"input ids2: {example['input_ids2'][j]}") if example["images"] is not None: im = example["images"][j] logger.info(f"image size: {im.size()}") @@ -2274,6 +2736,13 @@ def debug_dataset(train_dataset, show_input_ids=False): if os.name == "nt": cv2.imshow("cond_img", cond_img) + if "alpha_masks" in example and example["alpha_masks"] is not None: + alpha_mask = example["alpha_masks"][j] + logger.info(f"alpha mask size: {alpha_mask.size()}") + alpha_mask = (alpha_mask.numpy() * 255.0).astype(np.uint8) + if os.name == "nt": + cv2.imshow("alpha_mask", alpha_mask) + if os.name == "nt": # only windows cv2.imshow("img", im) k = cv2.waitKey() @@ -2319,8 +2788,8 @@ def glob_images_pathlib(dir_path, recursive): class MinimalDataset(BaseDataset): - def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False): - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) + def __init__(self, resolution, network_multiplier, debug_dataset=False): + super().__init__(resolution, network_multiplier, debug_dataset) self.num_train_images = 0 # update in subclass self.num_reg_images = 0 # update in subclass @@ -2382,7 +2851,7 @@ class MinimalDataset(BaseDataset): raise NotImplementedError -def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: +def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset: module = ".".join(args.dataset_class.split(".")[:-1]) dataset_class = args.dataset_class.split(".")[-1] module = importlib.import_module(module) @@ -2391,24 +2860,35 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: return train_dataset_group -def load_image(image_path): - image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") - img = np.array(image, np.uint8) - return img +def load_image(image_path, alpha=False): + try: + with Image.open(image_path) as image: + if alpha: + if not image.mode == "RGBA": + image = image.convert("RGBA") + else: + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) + return img + except (IOError, OSError) as e: + logger.error(f"Error loading file: {image_path}") + raise e # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom) def trim_and_resize_if_required( - random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int] + random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int] ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]: image_height, image_width = image.shape[0:2] original_size = (image_width, image_height) # size before resize if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + if image_width > resized_size[0] and image_height > resized_size[1]: + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + else: + image = pil_resize(image, resized_size) image_height, image_width = image.shape[0:2] @@ -2432,8 +2912,53 @@ def trim_and_resize_if_required( return image, original_size, crop_ltrb +# for new_cache_latents +def load_images_and_masks_for_caching( + image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool +) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: + r""" + requires image_infos to have: [absolute_path or image], bucket_reso, resized_size + + returns: image_tensor, alpha_masks, original_sizes, crop_ltrbs + + image_tensor: torch.Tensor = torch.Size([B, 3, H, W]), ...], normalized to [-1, 1] + alpha_masks: List[np.ndarray] = [np.ndarray([H, W]), ...], normalized to [0, 1] + original_sizes: List[Tuple[int, int]] = [(W, H), ...] + crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...] + """ + images: List[torch.Tensor] = [] + alpha_masks: List[np.ndarray] = [] + original_sizes: List[Tuple[int, int]] = [] + crop_ltrbs: List[Tuple[int, int, int, int]] = [] + for info in image_infos: + image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) + # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 + image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + + original_sizes.append(original_size) + crop_ltrbs.append(crop_ltrb) + + if use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] + else: + alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] + else: + alpha_mask = None + alpha_masks.append(alpha_mask) + + image = image[:, :, :3] # remove alpha channel if exists + image = IMAGE_TRANSFORMS(image) + images.append(image) + + img_tensor = torch.stack(images, dim=0) + return img_tensor, alpha_masks, original_sizes, crop_ltrbs + + def cache_batch_latents( - vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool + vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool ) -> None: r""" requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz @@ -2445,16 +2970,30 @@ def cache_batch_latents( latents_original_size and latents_crop_ltrb are also set """ images = [] + alpha_masks: List[np.ndarray] = [] for info in image_infos: - image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8) + image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) - image = IMAGE_TRANSFORMS(image) - images.append(image) info.latents_original_size = original_size info.latents_crop_ltrb = crop_ltrb + if use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] + else: + alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] + else: + alpha_mask = None + alpha_masks.append(alpha_mask) + + image = image[:, :, :3] # remove alpha channel if exists + image = IMAGE_TRANSFORMS(image) + images.append(image) + img_tensors = torch.stack(images, dim=0) img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) @@ -2468,17 +3007,26 @@ def cache_batch_latents( else: flipped_latents = [None] * len(latents) - for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents): + for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks): # check NaN if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()): raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") if cache_to_disk: - save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent) + # save_latents_to_disk( + # info.latents_npz, + # latent, + # info.latents_original_size, + # info.latents_crop_ltrb, + # flipped_latent, + # alpha_mask, + # ) + pass else: info.latents = latent if flip_aug: info.latents_flipped = flipped_latent + info.alpha_mask = alpha_mask if not HIGH_VRAM: clean_memory_on_device(vae.device) @@ -2516,6 +3064,34 @@ def cache_batch_text_encoder_outputs( info.text_encoder_pool2 = pool2 +def cache_batch_text_encoder_outputs_sd3( + image_infos, tokenizer, text_encoders, max_token_length, cache_to_disk, input_ids, output_dtype +): + # make input_ids for each text encoder + l_tokens, g_tokens, t5_tokens = input_ids + + clip_l, clip_g, t5xxl = text_encoders + with torch.no_grad(): + b_lg_out, b_t5_out, b_pool = sd3_utils.get_cond_from_tokens( + l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, "cpu", output_dtype + ) + b_lg_out = b_lg_out.detach() + b_t5_out = b_t5_out.detach() + b_pool = b_pool.detach() + + for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool): + # debug: NaN check + if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any(): + raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}") + + if cache_to_disk: + save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool) + else: + info.text_encoder_outputs1 = lg_out + info.text_encoder_outputs2 = t5_out + info.text_encoder_pool2 = pool + + def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2): np.savez( npz_path, @@ -2809,7 +3385,7 @@ SS_METADATA_MINIMUM_KEYS = [ def build_minimum_network_metadata( - v2: Optional[bool], + v2: Optional[str], base_model: Optional[str], network_module: str, network_dim: str, @@ -2838,6 +3414,8 @@ def get_sai_model_spec( lora: bool, textual_inversion: bool, is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA + sd3: str = None, + flux: str = None, ): timestamp = time.time() @@ -2871,6 +3449,8 @@ def get_sai_model_spec( tags=args.metadata_tags, timesteps=timesteps, clip_skip=args.clip_skip, # None or int + sd3=sd3, + flux=flux, ) return metadata @@ -2898,11 +3478,29 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): def add_optimizer_arguments(parser: argparse.ArgumentParser): + def int_or_float(value): + if value.endswith("%"): + try: + return float(value[:-1]) / 100.0 + except ValueError: + raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage") + try: + float_value = float(value) + if float_value >= 1: + return int(value) + return float(value) + except ValueError: + raise argparse.ArgumentTypeError(f"'{value}' is not an int or float") + parser.add_argument( "--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, " + "Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, " + "DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, " + "AdaFactor. " + "Also, you can use any optimizer by specifying the full path to the class, like 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit'.", ) # backward compatibility @@ -2933,6 +3531,20 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")', ) + # parser.add_argument( + # "--optimizer_schedulefree_wrapper", + # action="store_true", + # help="use schedulefree_wrapper any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用", + # ) + + # parser.add_argument( + # "--schedulefree_wrapper_args", + # type=str, + # default=None, + # nargs="*", + # help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / オプティマイザの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ...")', + # ) + parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ") parser.add_argument( "--lr_scheduler_args", @@ -2950,9 +3562,17 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): ) parser.add_argument( "--lr_warmup_steps", - type=int, + type=int_or_float, default=0, - help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)", + help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps" + " / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)", + ) + parser.add_argument( + "--lr_decay_steps", + type=int_or_float, + default=0, + help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps" + " / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)", ) parser.add_argument( "--lr_scheduler_num_cycles", @@ -2966,6 +3586,26 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): default=1, help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", ) + parser.add_argument( + "--fused_backward_pass", + action="store_true", + help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL, SD3 and FLUX" + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXL、SD3、FLUXでのみ利用可能", + ) + parser.add_argument( + "--lr_scheduler_timescale", + type=int, + default=None, + help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`" + + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", + ) + parser.add_argument( + "--lr_scheduler_min_lr_ratio", + type=float, + default=None, + help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler" + + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", + ) def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): @@ -3097,7 +3737,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: # available backends: # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 # https://pytorch.org/docs/stable/torch.compiler.html - choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], + choices=[ + "eager", + "aot_eager", + "inductor", + "aot_ts_nvfuser", + "nvprims_nvfuser", + "cudagraphs", + "ofi", + "fx2trt", + "onnxrt", + "tensort", + "ipex", + "tvm", + ], help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)", ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") @@ -3216,6 +3869,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)", ) + parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する") parser.add_argument( "--noise_offset", @@ -3286,8 +3940,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--loss_type", type=str, default="l2", - choices=["l2", "huber", "smooth_l1"], - help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L2、Huber、またはsmooth L1)、デフォルトはL2", + choices=["l1", "l2", "huber", "smooth_l1"], + help="The type of loss function to use (L1, L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L1、L2、Huber、またはsmooth L1)、デフォルトはL2", ) parser.add_argument( "--huber_schedule", @@ -3301,7 +3955,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--huber_c", type=float, default=0.1, - help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1" + " / Huber損失の減衰パラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + ) + + parser.add_argument( + "--huber_scale", + type=float, + default=1.0, + help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0" + " / Huber損失のスケールパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは1.0", ) parser.add_argument( @@ -3425,6 +4088,108 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser): ) +def add_dit_training_arguments(parser: argparse.ArgumentParser): + # Text encoder related arguments + parser.add_argument( + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) + parser.add_argument( + "--text_encoder_batch_size", + type=int, + default=None, + help="text encoder batch size (default: None, use dataset's batch size)" + + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", + ) + + # Model loading optimization + parser.add_argument( + "--disable_mmap_load_safetensors", + action="store_true", + help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", + ) + + # Training arguments. partial copy from Diffusers + parser.add_argument( + "--weighting_scheme", + type=str, + default="uniform", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none", "uniform"], + help="weighting scheme for timestep distribution. Default is uniform, uniform and none are the same behavior" + " / タイムステップ分布の重み付けスキーム、デフォルトはuniform、uniform と none は同じ挙動", + ) + parser.add_argument( + "--logit_mean", + type=float, + default=0.0, + help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均", + ) + parser.add_argument( + "--logit_std", + type=float, + default=1.0, + help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd", + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール", + ) + + # offloading + parser.add_argument( + "--blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of blocks to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップするブロックの数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) + + +def get_sanitized_config_or_none(args: argparse.Namespace): + # if `--log_config` is enabled, return args for logging. if not, return None. + # when `--log_config is enabled, filter out sensitive values from args + # if wandb is not enabled, the log is not exposed to the public, but it is fine to filter out sensitive values to be safe + + if not args.log_config: + return None + + sensitive_args = ["wandb_api_key", "huggingface_token"] + sensitive_path_args = [ + "pretrained_model_name_or_path", + "vae", + "tokenizer_cache_dir", + "train_data_dir", + "conditioning_data_dir", + "reg_data_dir", + "output_dir", + "logging_dir", + ] + filtered_args = {} + for k, v in vars(args).items(): + # filter out sensitive values and convert to string if necessary + if k not in sensitive_args + sensitive_path_args: + # Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`. + if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int): + filtered_args[k] = v + # accelerate does not support lists + elif isinstance(v, list): + filtered_args[k] = f"{v}" + # accelerate does not support objects + elif isinstance(v, object): + filtered_args[k] = f"{v}" + + return filtered_args + + # verify command line args for training def verify_command_line_training_args(args: argparse.Namespace): # if wandb is enabled, the command line is exposed to the public @@ -3479,20 +4244,20 @@ def verify_command_line_training_args(args: argparse.Namespace): ) +def enable_high_vram(args: argparse.Namespace): + if args.highvram: + logger.info("highvram is enabled / highvramが有効です") + global HIGH_VRAM + HIGH_VRAM = True + + def verify_training_args(args: argparse.Namespace): r""" Verify training arguments. Also reflect highvram option to global variable 学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する """ - if args.highvram: - print("highvram is enabled / highvramが有効です") - global HIGH_VRAM - HIGH_VRAM = True + enable_high_vram(args) - if args.v_parameterization and not args.v2: - logger.warning( - "v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません" - ) if args.v2 and args.clip_skip is not None: logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") @@ -3651,13 +4416,31 @@ def add_dataset_arguments( action="store_true", help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)", ) + parser.add_argument( + "--skip_cache_check", + action="store_true", + help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist" + " / cacheの内容の検証をスキップする(latentとテキストエンコーダの出力)。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる", + ) parser.add_argument( "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする", ) - parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") - parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") + parser.add_argument( + "--min_bucket_reso", + type=int, + default=256, + help="minimum resolution for buckets, must be divisible by bucket_reso_steps " + " / bucketの最小解像度、bucket_reso_stepsで割り切れる必要があります", + ) + parser.add_argument( + "--max_bucket_reso", + type=int, + default=1024, + help="maximum resolution for buckets, must be divisible by bucket_reso_steps " + " / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります", + ) parser.add_argument( "--bucket_reso_steps", type=int, @@ -3682,6 +4465,11 @@ def add_dataset_arguments( default=0, help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)", ) + parser.add_argument( + "--alpha_mask", + action="store_true", + help="use alpha channel as mask for training / 画像のアルファチャンネルをlossのマスクに使用する", + ) parser.add_argument( "--dataset_class", @@ -3869,8 +4657,8 @@ def resume_from_local_or_hf_if_specified(accelerator, args): accelerator.load_state(dirname) -def get_optimizer(args, trainable_params): - # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" +def get_optimizer(args, trainable_params) -> tuple[str, str, object]: + # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" optimizer_type = args.optimizer_type if args.use_8bit_adam: @@ -3892,6 +4680,14 @@ def get_optimizer(args, trainable_params): optimizer_type = "AdamW" optimizer_type = optimizer_type.lower() + if args.fused_backward_pass: + assert ( + optimizer_type == "Adafactor".lower() + ), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します" + assert ( + args.gradient_accumulation_steps == 1 + ), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません" + # 引数を分解する optimizer_kwargs = {} if args.optimizer_args is not None and len(args.optimizer_args) > 0: @@ -3915,6 +4711,7 @@ def get_optimizer(args, trainable_params): lr = args.learning_rate optimizer = None + optimizer_class = None if optimizer_type == "Lion".lower(): try: @@ -3972,7 +4769,8 @@ def get_optimizer(args, trainable_params): "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + if optimizer_class is not None: + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower(): logger.info(f"use PagedAdamW optimizer | {optimizer_kwargs}") @@ -4134,26 +4932,167 @@ def get_optimizer(args, trainable_params): optimizer_class = torch.optim.AdamW optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + elif optimizer_type.endswith("schedulefree".lower()): + try: + import schedulefree as sf + except ImportError: + raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") + + if optimizer_type == "RAdamScheduleFree".lower(): + optimizer_class = sf.RAdamScheduleFree + logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}") + elif optimizer_type == "AdamWScheduleFree".lower(): + optimizer_class = sf.AdamWScheduleFree + logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}") + elif optimizer_type == "SGDScheduleFree".lower(): + optimizer_class = sf.SGDScheduleFree + logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}") + else: + optimizer_class = None + + if optimizer_class is not None: + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + if optimizer is None: # 任意のoptimizerを使う - optimizer_type = args.optimizer_type # lowerでないやつ(微妙) - logger.info(f"use {optimizer_type} | {optimizer_kwargs}") - if "." not in optimizer_type: - optimizer_module = torch.optim - else: - values = optimizer_type.split(".") - optimizer_module = importlib.import_module(".".join(values[:-1])) - optimizer_type = values[-1] + case_sensitive_optimizer_type = args.optimizer_type # not lower + logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}") - optimizer_class = getattr(optimizer_module, optimizer_type) + if "." not in case_sensitive_optimizer_type: # from torch.optim + optimizer_module = torch.optim + else: # from other library + values = case_sensitive_optimizer_type.split(".") + optimizer_module = importlib.import_module(".".join(values[:-1])) + case_sensitive_optimizer_type = values[-1] + + optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + """ + # wrap any of above optimizer with schedulefree, if optimizer is not schedulefree + if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree".lower()): + try: + import schedulefree as sf + except ImportError: + raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") + + schedulefree_wrapper_kwargs = {} + if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0: + for arg in args.schedulefree_wrapper_args: + key, value = arg.split("=") + value = ast.literal_eval(value) + schedulefree_wrapper_kwargs[key] = value + + sf_wrapper = sf.ScheduleFreeWrapper(optimizer, **schedulefree_wrapper_kwargs) + sf_wrapper.train() # make optimizer as train mode + + # we need to make optimizer as a subclass of torch.optim.Optimizer, we make another Proxy class over SFWrapper + class OptimizerProxy(torch.optim.Optimizer): + def __init__(self, sf_wrapper): + self._sf_wrapper = sf_wrapper + + def __getattr__(self, name): + return getattr(self._sf_wrapper, name) + + # override properties + @property + def state(self): + return self._sf_wrapper.state + + @state.setter + def state(self, state): + self._sf_wrapper.state = state + + @property + def param_groups(self): + return self._sf_wrapper.param_groups + + @param_groups.setter + def param_groups(self, param_groups): + self._sf_wrapper.param_groups = param_groups + + @property + def defaults(self): + return self._sf_wrapper.defaults + + @defaults.setter + def defaults(self, defaults): + self._sf_wrapper.defaults = defaults + + def add_param_group(self, param_group): + self._sf_wrapper.add_param_group(param_group) + + def load_state_dict(self, state_dict): + self._sf_wrapper.load_state_dict(state_dict) + + def state_dict(self): + return self._sf_wrapper.state_dict() + + def zero_grad(self): + self._sf_wrapper.zero_grad() + + def step(self, closure=None): + self._sf_wrapper.step(closure) + + def train(self): + self._sf_wrapper.train() + + def eval(self): + self._sf_wrapper.eval() + + # isinstance チェックをパスするためのメソッド + def __instancecheck__(self, instance): + return isinstance(instance, (type(self), Optimizer)) + + optimizer = OptimizerProxy(sf_wrapper) + + logger.info(f"wrap optimizer with ScheduleFreeWrapper | {schedulefree_wrapper_kwargs}") + """ + + # for logging optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) + if hasattr(optimizer, "train") and callable(optimizer.train): + # make optimizer as train mode before training for schedulefree optimizer. the optimizer will be in eval mode in sampling and saving. + optimizer.train() + return optimizer_name, optimizer_args, optimizer +def get_optimizer_train_eval_fn(optimizer: Optimizer, args: argparse.Namespace) -> Tuple[Callable, Callable]: + if not is_schedulefree_optimizer(optimizer, args): + # return dummy func + return lambda: None, lambda: None + + # get train and eval functions from optimizer + train_fn = optimizer.train + eval_fn = optimizer.eval + + return train_fn, eval_fn + + +def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool: + return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper + + +def get_dummy_scheduler(optimizer: Optimizer) -> Any: + # dummy scheduler for schedulefree optimizer. supports only empty step(), get_last_lr() and optimizers. + # this scheduler is used for logging only. + # this isn't be wrapped by accelerator because of this class is not a subclass of torch.optim.lr_scheduler._LRScheduler + class DummyScheduler: + def __init__(self, optimizer: Optimizer): + self.optimizer = optimizer + + def step(self): + pass + + def get_last_lr(self): + return [group["lr"] for group in self.optimizer.param_groups] + + return DummyScheduler(optimizer) + + # Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler # Add some checking and features to the original function. @@ -4162,11 +5101,23 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): """ Unified API to get any scheduler from its name. """ + # if schedulefree optimizer, return dummy scheduler + if is_schedulefree_optimizer(optimizer, args): + return get_dummy_scheduler(optimizer) + name = args.lr_scheduler - num_warmup_steps: Optional[int] = args.lr_warmup_steps num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps + num_warmup_steps: Optional[int] = ( + int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps + ) + num_decay_steps: Optional[int] = ( + int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps + ) + num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps num_cycles = args.lr_scheduler_num_cycles power = args.lr_scheduler_power + timescale = args.lr_scheduler_timescale + min_lr_ratio = args.lr_scheduler_min_lr_ratio lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0: @@ -4202,15 +5153,17 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): # logger.info(f"adafactor scheduler init lr {initial_lr}") return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) + if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value: + name = DiffusersSchedulerType(name) + schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name] + return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs + name = SchedulerType(name) schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] if name == SchedulerType.CONSTANT: return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs)) - if name == SchedulerType.PIECEWISE_CONSTANT: - return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs - # All other schedulers require `num_warmup_steps` if num_warmup_steps is None: raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") @@ -4218,6 +5171,9 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): if name == SchedulerType.CONSTANT_WITH_WARMUP: return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs) + if name == SchedulerType.INVERSE_SQRT: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs) + # All other schedulers require `num_training_steps` if num_training_steps is None: raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") @@ -4236,7 +5192,46 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs ) - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs) + if name == SchedulerType.COSINE_WITH_MIN_LR: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles / 2, + min_lr_rate=min_lr_ratio, + **lr_scheduler_kwargs, + ) + + # these schedulers do not require `num_decay_steps` + if name == SchedulerType.LINEAR or name == SchedulerType.COSINE: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **lr_scheduler_kwargs, + ) + + # All other schedulers require `num_decay_steps` + if num_decay_steps is None: + raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") + if name == SchedulerType.WARMUP_STABLE_DECAY: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_stable_steps=num_stable_steps, + num_decay_steps=num_decay_steps, + num_cycles=num_cycles / 2, + min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0, + **lr_scheduler_kwargs, + ) + + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_decay_steps=num_decay_steps, + **lr_scheduler_kwargs, + ) def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): @@ -4269,33 +5264,6 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): ) -def load_tokenizer(args: argparse.Namespace): - logger.info("prepare tokenizer") - original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH - - tokenizer: CLIPTokenizer = None - if args.tokenizer_cache_dir: - local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) - if os.path.exists(local_tokenizer_path): - logger.info(f"load tokenizer from cache: {local_tokenizer_path}") - tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2 - - if tokenizer is None: - if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer") - else: - tokenizer = CLIPTokenizer.from_pretrained(original_path) - - if hasattr(args, "max_token_length") and args.max_token_length is not None: - logger.info(f"update token length: {args.max_token_length}") - - if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") - tokenizer.save_pretrained(local_tokenizer_path) - - return tokenizer - - def prepare_accelerator(args: argparse.Namespace): """ this function also prepares deepspeed plugin @@ -4335,8 +5303,18 @@ def prepare_accelerator(args: argparse.Namespace): if args.torch_compile: dynamo_backend = args.dynamo_backend - kwargs_handlers = ( - InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None, + kwargs_handlers = [ + ( + InitProcessGroupKwargs( + backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method=( + "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None + ), + timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None, + ) + if torch.cuda.device_count() > 1 + else None + ), ( DistributedDataParallelKwargs( gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph @@ -4344,8 +5322,8 @@ def prepare_accelerator(args: argparse.Namespace): if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None ), - ) - kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers)) + ] + kwargs_handlers = [i for i in kwargs_handlers if i is not None] deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) accelerator = Accelerator( @@ -4928,7 +5906,7 @@ def get_random_timesteps(args, min_timestep: int, max_timestep: int, batch_size: return typing.cast(torch.IntTensor, timesteps) -def get_huber_c(args, noise_scheduler: DDPMScheduler, timesteps: torch.IntTensor) -> float: +def get_huber_c(args, noise_scheduler: DDPMScheduler, timesteps: torch.IntTensor) -> Optional[float]: """ Calculate the Huber convolution (huber_c) value Huber loss is a loss function used in robust regression, that is less sensitive @@ -4940,6 +5918,8 @@ def get_huber_c(args, noise_scheduler: DDPMScheduler, timesteps: torch.IntTensor alpha = -math.log(args.huber_c) / noise_scheduler.config.get('num_train_timesteps', 1000) huber_c = math.exp(-alpha * timesteps.item()) elif args.huber_schedule == "snr": + if not hasattr(noise_scheduler, "alphas_cumprod"): + raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") alphas_cumprod = noise_scheduler.alphas_cumprod.index_select(0, timesteps) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c @@ -4949,7 +5929,7 @@ def get_huber_c(args, noise_scheduler: DDPMScheduler, timesteps: torch.IntTensor raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") elif args.loss_type == "l2": - huber_c = 1 # may be anything, as it's not used + return None else: raise NotImplementedError(f"Unknown loss type {args.loss_type}") @@ -4997,7 +5977,7 @@ def make_random_timesteps(args, noise_scheduler: DDPMScheduler, batch_size: int, return timesteps -def get_noise_noisy_latents_and_timesteps(args, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor, float]: +def get_noise_noisy_latents_and_timesteps(args, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor, Optional[float]]: """ Unified noise, noisy_latents, timesteps and huber loss convolution calculations """ @@ -5033,27 +6013,31 @@ def get_noisy_latents(args, noise: torch.FloatTensor, noise_scheduler: DDPMSched def conditional_loss( - model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1 + model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None ): """ NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already """ if loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) + elif loss_type == "l1": + loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) elif loss_type == "huber": + huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) elif reduction == "sum": loss = torch.sum(loss) elif loss_type == "smooth_l1": + huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) elif reduction == "sum": loss = torch.sum(loss) else: - raise NotImplementedError(f"Unsupported Loss Type {loss_type}") + raise NotImplementedError(f"Unsupported Loss Type: {loss_type}") return loss @@ -5063,6 +6047,7 @@ def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True): names.append("unet") names.append("text_encoder1") names.append("text_encoder2") + names.append("text_encoder3") # SD3 append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names) @@ -5197,12 +6182,43 @@ def line_to_prompt_dict(line: str) -> dict: return prompt_dict +def load_prompts(prompt_file: str) -> List[Dict]: + # read prompts + if prompt_file.endswith(".txt"): + with open(prompt_file, "r", encoding="utf-8") as f: + lines = f.readlines() + prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] + elif prompt_file.endswith(".toml"): + with open(prompt_file, "r", encoding="utf-8") as f: + data = toml.load(f) + prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] + elif prompt_file.endswith(".json"): + with open(prompt_file, "r", encoding="utf-8") as f: + prompts = json.load(f) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from library.train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + return prompts + + def sample_images_common( pipe_class, accelerator: Accelerator, args: argparse.Namespace, - epoch, - steps, + epoch: int, + steps: int, device, vae, tokenizer, @@ -5213,6 +6229,7 @@ def sample_images_common( ): """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した + TODO Use strategies here """ if steps == 0: @@ -5260,11 +6277,7 @@ def sample_images_common( with open(args.sample_prompts, "r", encoding="utf-8") as f: prompts = json.load(f) - # schedulers: dict = {} cannot find where this is used - default_scheduler = get_my_scheduler( - sample_sampler=args.sample_sampler, - v_parameterization=args.v_parameterization, - ) + default_scheduler = get_my_scheduler(sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization) pipeline = pipe_class( text_encoder=text_encoder, @@ -5325,21 +6338,18 @@ def sample_images_common( # clear pipeline and cache to reduce vram usage del pipeline - # I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here. - # with torch.cuda.device(torch.cuda.current_device()): - # torch.cuda.empty_cache() - clean_memory_on_device(accelerator.device) - torch.set_rng_state(rng_state) - if cuda_rng_state is not None: + if torch.cuda.is_available() and cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + def sample_image_inference( accelerator: Accelerator, args: argparse.Namespace, - pipeline, + pipeline: Union[StableDiffusionLongPromptWeightingPipeline, SdxlStableDiffusionLongPromptWeightingPipeline], save_dir, prompt_dict, epoch, @@ -5365,11 +6375,13 @@ def sample_image_inference( if seed is not None: torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) else: # True random sample image generation torch.seed() - torch.cuda.seed() + if torch.cuda.is_available(): + torch.cuda.seed() scheduler = get_my_scheduler( sample_sampler=sampler_name, @@ -5404,8 +6416,9 @@ def sample_image_inference( controlnet_image=controlnet_image, ) - with torch.cuda.device(torch.cuda.current_device()): - torch.cuda.empty_cache() + if torch.cuda.is_available(): + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() image = pipeline.latents_to_image(latents)[0] @@ -5419,17 +6432,14 @@ def sample_image_inference( img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" image.save(os.path.join(save_dir, img_filename)) - # wandb有効時のみログを送信 - try: + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + import wandb + + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption # endregion @@ -5492,6 +6502,8 @@ class LossRecorder: if epoch == 0: self.loss_list.append(loss) else: + while len(self.loss_list) <= step: + self.loss_list.append(0.0) self.loss_total -= self.loss_list[step] self.loss_list[step] = loss self.loss_total += loss diff --git a/library/utils.py b/library/utils.py index 3037c055..07079c6d 100644 --- a/library/utils.py +++ b/library/utils.py @@ -1,18 +1,29 @@ import logging import sys import threading -import torch -from torchvision import transforms from typing import * +import json +import struct + +import torch +import torch.nn as nn +from torchvision import transforms from diffusers import EulerAncestralDiscreteScheduler import diffusers.schedulers.scheduling_euler_ancestral_discrete from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput +import cv2 +from PIL import Image +import numpy as np +from safetensors.torch import load_file def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() +# region Logging + + def add_logging_arguments(parser): parser.add_argument( "--console_log_level", @@ -79,10 +90,315 @@ def setup_logging(args=None, log_level=None, reset=False): logger.info(msg_init) +# endregion + +# region PyTorch utils + + +def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + # cuda to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.record_stream(stream) + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + stream.synchronize() + + # cpu to cuda + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + stream.synchronize() + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def weighs_to_device(layer: nn.Module, device: torch.device): + for module in layer.modules(): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data = module.weight.data.to(device, non_blocking=True) + + +def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: + """ + Convert a string to a torch.dtype + + Args: + s: string representation of the dtype + default_dtype: default dtype to return if s is None + + Returns: + torch.dtype: the corresponding torch.dtype + + Raises: + ValueError: if the dtype is not supported + + Examples: + >>> str_to_dtype("float32") + torch.float32 + >>> str_to_dtype("fp32") + torch.float32 + >>> str_to_dtype("float16") + torch.float16 + >>> str_to_dtype("fp16") + torch.float16 + >>> str_to_dtype("bfloat16") + torch.bfloat16 + >>> str_to_dtype("bf16") + torch.bfloat16 + >>> str_to_dtype("fp8") + torch.float8_e4m3fn + >>> str_to_dtype("fp8_e4m3fn") + torch.float8_e4m3fn + >>> str_to_dtype("fp8_e4m3fnuz") + torch.float8_e4m3fnuz + >>> str_to_dtype("fp8_e5m2") + torch.float8_e5m2 + >>> str_to_dtype("fp8_e5m2fnuz") + torch.float8_e5m2fnuz + """ + if s is None: + return default_dtype + if s in ["bf16", "bfloat16"]: + return torch.bfloat16 + elif s in ["fp16", "float16"]: + return torch.float16 + elif s in ["fp32", "float32", "float"]: + return torch.float32 + elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]: + return torch.float8_e4m3fn + elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]: + return torch.float8_e4m3fnuz + elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]: + return torch.float8_e5m2 + elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]: + return torch.float8_e5m2fnuz + elif s in ["fp8", "float8"]: + return torch.float8_e4m3fn # default fp8 + else: + raise ValueError(f"Unsupported dtype: {s}") + + +def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None): + """ + memory efficient save file + """ + + _TYPES = { + torch.float64: "F64", + torch.float32: "F32", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int64: "I64", + torch.int32: "I32", + torch.int16: "I16", + torch.int8: "I8", + torch.uint8: "U8", + torch.bool: "BOOL", + getattr(torch, "float8_e5m2", None): "F8_E5M2", + getattr(torch, "float8_e4m3fn", None): "F8_E4M3", + } + _ALIGN = 256 + + def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]: + validated = {} + for key, value in metadata.items(): + if not isinstance(key, str): + raise ValueError(f"Metadata key must be a string, got {type(key)}") + if not isinstance(value, str): + print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.") + validated[key] = str(value) + else: + validated[key] = value + return validated + + print(f"Using memory efficient save file: {filename}") + + header = {} + offset = 0 + if metadata: + header["__metadata__"] = validate_metadata(metadata) + for k, v in tensors.items(): + if v.numel() == 0: # empty tensor + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]} + else: + size = v.numel() * v.element_size() + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]} + offset += size + + hjson = json.dumps(header).encode("utf-8") + hjson += b" " * (-(len(hjson) + 8) % _ALIGN) + + with open(filename, "wb") as f: + f.write(struct.pack(" dict[str, torch.Tensor]: + if disable_mmap: + # return safetensors.torch.load(open(path, "rb").read()) + # use experimental loader + # logger.info(f"Loading without mmap (experimental)") + state_dict = {} + with MemoryEfficientSafeOpen(path) as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key).to(device, dtype=dtype) + return state_dict + else: + try: + state_dict = load_file(path, device=device) + except: + state_dict = load_file(path) # prevent device invalid Error + if dtype is not None: + for key in state_dict.keys(): + state_dict[key] = state_dict[key].to(dtype=dtype) + return state_dict + + +# endregion + +# region Image utils + + +def pil_resize(image, size, interpolation=Image.LANCZOS): + has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False + + if has_alpha: + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)) + else: + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + resized_pil = pil_image.resize(size, interpolation) + + # Convert back to cv2 format + if has_alpha: + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA) + else: + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) + + return resized_cv2 + + +# endregion # TODO make inf_utils.py - - # region Gradual Latent hires fix diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index 794659c9..f8eab53b 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -18,7 +18,7 @@ def main(file): keys = list(sd.keys()) for key in keys: - if "lora_up" in key or "lora_down" in key: + if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key or "oft_" in key: values.append((key, sd[key])) print(f"number of LoRA modules: {len(values)}") diff --git a/networks/control_net_lllite_for_train.py b/networks/control_net_lllite_for_train.py index 65b3520c..366451b7 100644 --- a/networks/control_net_lllite_for_train.py +++ b/networks/control_net_lllite_for_train.py @@ -7,8 +7,10 @@ from typing import Optional, List, Type import torch from library import sdxl_original_unet from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) # input_blocksに適用するかどうか / if True, input_blocks are not applied @@ -103,19 +105,15 @@ class LLLiteLinear(ORIGINAL_LINEAR): add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim) self.cond_image = None - self.cond_emb = None def set_cond_image(self, cond_image): self.cond_image = cond_image - self.cond_emb = None def forward(self, x): if not self.enabled: return super().forward(x) - if self.cond_emb is None: - self.cond_emb = self.lllite_conditioning1(self.cond_image) - cx = self.cond_emb + cx = self.lllite_conditioning1(self.cond_image) # make forward and backward compatible # reshape / b,c,h,w -> b,h*w,c n, c, h, w = cx.shape @@ -159,9 +157,7 @@ class LLLiteConv2d(ORIGINAL_CONV2D): if not self.enabled: return super().forward(x) - if self.cond_emb is None: - self.cond_emb = self.lllite_conditioning1(self.cond_image) - cx = self.cond_emb + cx = self.lllite_conditioning1(self.cond_image) cx = torch.cat([cx, self.down(x)], dim=1) cx = self.mid(cx) diff --git a/networks/convert_flux_lora.py b/networks/convert_flux_lora.py new file mode 100644 index 00000000..fe6466eb --- /dev/null +++ b/networks/convert_flux_lora.py @@ -0,0 +1,434 @@ +# convert key mapping and data format from some LoRA format to another +""" +Original LoRA format: Based on Black Forest Labs, QKV and MLP are unified into one module +alpha is scalar for each LoRA module + +0 to 18 +lora_unet_double_blocks_0_img_attn_proj.alpha torch.Size([]) +lora_unet_double_blocks_0_img_attn_proj.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_attn_proj.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_img_attn_qkv.alpha torch.Size([]) +lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight torch.Size([9216, 4]) +lora_unet_double_blocks_0_img_mlp_0.alpha torch.Size([]) +lora_unet_double_blocks_0_img_mlp_0.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_mlp_0.lora_up.weight torch.Size([12288, 4]) +lora_unet_double_blocks_0_img_mlp_2.alpha torch.Size([]) +lora_unet_double_blocks_0_img_mlp_2.lora_down.weight torch.Size([4, 12288]) +lora_unet_double_blocks_0_img_mlp_2.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_img_mod_lin.alpha torch.Size([]) +lora_unet_double_blocks_0_img_mod_lin.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_img_mod_lin.lora_up.weight torch.Size([18432, 4]) +lora_unet_double_blocks_0_txt_attn_proj.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_attn_proj.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_attn_proj.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_txt_attn_qkv.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_attn_qkv.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_attn_qkv.lora_up.weight torch.Size([9216, 4]) +lora_unet_double_blocks_0_txt_mlp_0.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_mlp_0.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_mlp_0.lora_up.weight torch.Size([12288, 4]) +lora_unet_double_blocks_0_txt_mlp_2.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_mlp_2.lora_down.weight torch.Size([4, 12288]) +lora_unet_double_blocks_0_txt_mlp_2.lora_up.weight torch.Size([3072, 4]) +lora_unet_double_blocks_0_txt_mod_lin.alpha torch.Size([]) +lora_unet_double_blocks_0_txt_mod_lin.lora_down.weight torch.Size([4, 3072]) +lora_unet_double_blocks_0_txt_mod_lin.lora_up.weight torch.Size([18432, 4]) + +0 to 37 +lora_unet_single_blocks_0_linear1.alpha torch.Size([]) +lora_unet_single_blocks_0_linear1.lora_down.weight torch.Size([4, 3072]) +lora_unet_single_blocks_0_linear1.lora_up.weight torch.Size([21504, 4]) +lora_unet_single_blocks_0_linear2.alpha torch.Size([]) +lora_unet_single_blocks_0_linear2.lora_down.weight torch.Size([4, 15360]) +lora_unet_single_blocks_0_linear2.lora_up.weight torch.Size([3072, 4]) +lora_unet_single_blocks_0_modulation_lin.alpha torch.Size([]) +lora_unet_single_blocks_0_modulation_lin.lora_down.weight torch.Size([4, 3072]) +lora_unet_single_blocks_0_modulation_lin.lora_up.weight torch.Size([9216, 4]) +""" +""" +ai-toolkit: Based on Diffusers, QKV and MLP are separated into 3 modules. +A is down, B is up. No alpha for each LoRA module. + +0 to 18 +transformer.transformer_blocks.0.attn.add_k_proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.add_k_proj.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.add_v_proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.add_v_proj.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_add_out.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_add_out.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.ff.net.0.proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.ff.net.0.proj.lora_B.weight torch.Size([12288, 16]) +transformer.transformer_blocks.0.ff.net.2.lora_A.weight torch.Size([16, 12288]) +transformer.transformer_blocks.0.ff.net.2.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.ff_context.net.0.proj.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.ff_context.net.0.proj.lora_B.weight torch.Size([12288, 16]) +transformer.transformer_blocks.0.ff_context.net.2.lora_A.weight torch.Size([16, 12288]) +transformer.transformer_blocks.0.ff_context.net.2.lora_B.weight torch.Size([3072, 16]) +transformer.transformer_blocks.0.norm1.linear.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.norm1.linear.lora_B.weight torch.Size([18432, 16]) +transformer.transformer_blocks.0.norm1_context.linear.lora_A.weight torch.Size([16, 3072]) +transformer.transformer_blocks.0.norm1_context.linear.lora_B.weight torch.Size([18432, 16]) + +0 to 37 +transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([3072, 16]) +transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([3072, 16]) +transformer.single_transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([3072, 16]) +transformer.single_transformer_blocks.0.norm.linear.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.norm.linear.lora_B.weight torch.Size([9216, 16]) +transformer.single_transformer_blocks.0.proj_mlp.lora_A.weight torch.Size([16, 3072]) +transformer.single_transformer_blocks.0.proj_mlp.lora_B.weight torch.Size([12288, 16]) +transformer.single_transformer_blocks.0.proj_out.lora_A.weight torch.Size([16, 15360]) +transformer.single_transformer_blocks.0.proj_out.lora_B.weight torch.Size([3072, 16]) +""" +""" +xlabs: Unknown format. +0 to 18 +double_blocks.0.processor.proj_lora1.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.proj_lora1.up.weight torch.Size([3072, 16]) +double_blocks.0.processor.proj_lora2.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.proj_lora2.up.weight torch.Size([3072, 16]) +double_blocks.0.processor.qkv_lora1.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.qkv_lora1.up.weight torch.Size([9216, 16]) +double_blocks.0.processor.qkv_lora2.down.weight torch.Size([16, 3072]) +double_blocks.0.processor.qkv_lora2.up.weight torch.Size([9216, 16]) +""" + + +import argparse +from safetensors.torch import save_file +from safetensors import safe_open +import torch + + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def convert_to_sd_scripts(sds_sd, ait_sd, sds_key, ait_key): + ait_down_key = ait_key + ".lora_A.weight" + if ait_down_key not in ait_sd: + return + ait_up_key = ait_key + ".lora_B.weight" + + down_weight = ait_sd.pop(ait_down_key) + sds_sd[sds_key + ".lora_down.weight"] = down_weight + sds_sd[sds_key + ".lora_up.weight"] = ait_sd.pop(ait_up_key) + rank = down_weight.shape[0] + sds_sd[sds_key + ".alpha"] = torch.scalar_tensor(rank, dtype=down_weight.dtype, device=down_weight.device) + + +def convert_to_sd_scripts_cat(sds_sd, ait_sd, sds_key, ait_keys): + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + if ait_down_keys[0] not in ait_sd: + return + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + + down_weights = [ait_sd.pop(k) for k in ait_down_keys] + up_weights = [ait_sd.pop(k) for k in ait_up_keys] + + # lora_down is concatenated along dim=0, so rank is multiplied by the number of splits + rank = down_weights[0].shape[0] + num_splits = len(ait_keys) + sds_sd[sds_key + ".lora_down.weight"] = torch.cat(down_weights, dim=0) + + merged_up_weights = torch.zeros( + (sum(w.shape[0] for w in up_weights), rank * num_splits), + dtype=up_weights[0].dtype, + device=up_weights[0].device, + ) + + i = 0 + for j, up_weight in enumerate(up_weights): + merged_up_weights[i : i + up_weight.shape[0], j * rank : (j + 1) * rank] = up_weight + i += up_weight.shape[0] + + sds_sd[sds_key + ".lora_up.weight"] = merged_up_weights + + # set alpha to new_rank + new_rank = rank * num_splits + sds_sd[sds_key + ".alpha"] = torch.scalar_tensor(new_rank, dtype=down_weights[0].dtype, device=down_weights[0].device) + + +def convert_ai_toolkit_to_sd_scripts(ait_sd): + sds_sd = {} + for i in range(19): + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_out.0" + ) + convert_to_sd_scripts_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.to_q", + f"transformer.transformer_blocks.{i}.attn.to_k", + f"transformer.transformer_blocks.{i}.attn.to_v", + ], + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_0", f"transformer.transformer_blocks.{i}.ff.net.0.proj" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_2", f"transformer.transformer_blocks.{i}.ff.net.2" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mod_lin", f"transformer.transformer_blocks.{i}.norm1.linear" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_add_out" + ) + convert_to_sd_scripts_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.add_q_proj", + f"transformer.transformer_blocks.{i}.attn.add_k_proj", + f"transformer.transformer_blocks.{i}.attn.add_v_proj", + ], + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_0", f"transformer.transformer_blocks.{i}.ff_context.net.0.proj" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_2", f"transformer.transformer_blocks.{i}.ff_context.net.2" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mod_lin", f"transformer.transformer_blocks.{i}.norm1_context.linear" + ) + + for i in range(38): + convert_to_sd_scripts_cat( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear1", + [ + f"transformer.single_transformer_blocks.{i}.attn.to_q", + f"transformer.single_transformer_blocks.{i}.attn.to_k", + f"transformer.single_transformer_blocks.{i}.attn.to_v", + f"transformer.single_transformer_blocks.{i}.proj_mlp", + ], + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_linear2", f"transformer.single_transformer_blocks.{i}.proj_out" + ) + convert_to_sd_scripts( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_modulation_lin", f"transformer.single_transformer_blocks.{i}.norm.linear" + ) + + if len(ait_sd) > 0: + logger.warning(f"Unsuppored keys for sd-scripts: {ait_sd.keys()}") + return sds_sd + + +def convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + + # scale weight by alpha and dim + rank = down_weight.shape[0] + alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar + scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here + # print(f"rank: {rank}, alpha: {alpha}, scale: {scale}") + + # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2 + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + # print(f"scale: {scale}, scale_down: {scale_down}, scale_up: {scale_up}") + + ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down + ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up + + +def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): + if sds_key + ".lora_down.weight" not in sds_sd: + return + down_weight = sds_sd.pop(sds_key + ".lora_down.weight") + up_weight = sds_sd.pop(sds_key + ".lora_up.weight") + sd_lora_rank = down_weight.shape[0] + + # scale weight by alpha and dim + alpha = sds_sd.pop(sds_key + ".alpha") + scale = alpha / sd_lora_rank + + # calculate scale_down and scale_up + scale_down = scale + scale_up = 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + + down_weight = down_weight * scale_down + up_weight = up_weight * scale_up + + # calculate dims if not provided + num_splits = len(ait_keys) + if dims is None: + dims = [up_weight.shape[0] // num_splits] * num_splits + else: + assert sum(dims) == up_weight.shape[0] + + # check upweight is sparse or not + is_sparse = False + if sd_lora_rank % num_splits == 0: + ait_rank = sd_lora_rank // num_splits + is_sparse = True + i = 0 + for j in range(len(dims)): + for k in range(len(dims)): + if j == k: + continue + is_sparse = is_sparse and torch.all(up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0) + i += dims[j] + if is_sparse: + logger.info(f"weight is sparse: {sds_key}") + + # make ai-toolkit weight + ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] + ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] + if not is_sparse: + # down_weight is copied to each split + ait_sd.update({k: down_weight for k in ait_down_keys}) + + # up_weight is split to each split + ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) + else: + # down_weight is chunked to each split + ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) + + # up_weight is sparse: only non-zero values are copied to each split + i = 0 + for j in range(len(dims)): + ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous() + i += dims[j] + + +def convert_sd_scripts_to_ai_toolkit(sds_sd): + ait_sd = {} + for i in range(19): + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_out.0" + ) + convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_img_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.to_q", + f"transformer.transformer_blocks.{i}.attn.to_k", + f"transformer.transformer_blocks.{i}.attn.to_v", + ], + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_0", f"transformer.transformer_blocks.{i}.ff.net.0.proj" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_2", f"transformer.transformer_blocks.{i}.ff.net.2" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mod_lin", f"transformer.transformer_blocks.{i}.norm1.linear" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_add_out" + ) + convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_double_blocks_{i}_txt_attn_qkv", + [ + f"transformer.transformer_blocks.{i}.attn.add_q_proj", + f"transformer.transformer_blocks.{i}.attn.add_k_proj", + f"transformer.transformer_blocks.{i}.attn.add_v_proj", + ], + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_0", f"transformer.transformer_blocks.{i}.ff_context.net.0.proj" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_2", f"transformer.transformer_blocks.{i}.ff_context.net.2" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mod_lin", f"transformer.transformer_blocks.{i}.norm1_context.linear" + ) + + for i in range(38): + convert_to_ai_toolkit_cat( + sds_sd, + ait_sd, + f"lora_unet_single_blocks_{i}_linear1", + [ + f"transformer.single_transformer_blocks.{i}.attn.to_q", + f"transformer.single_transformer_blocks.{i}.attn.to_k", + f"transformer.single_transformer_blocks.{i}.attn.to_v", + f"transformer.single_transformer_blocks.{i}.proj_mlp", + ], + dims=[3072, 3072, 3072, 12288], + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_linear2", f"transformer.single_transformer_blocks.{i}.proj_out" + ) + convert_to_ai_toolkit( + sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_modulation_lin", f"transformer.single_transformer_blocks.{i}.norm.linear" + ) + + if len(sds_sd) > 0: + logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}") + return ait_sd + + +def main(args): + # load source safetensors + logger.info(f"Loading source file {args.src_path}") + state_dict = {} + with safe_open(args.src_path, framework="pt") as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + + logger.info(f"Converting {args.src} to {args.dst} format") + if args.src == "ai-toolkit" and args.dst == "sd-scripts": + state_dict = convert_ai_toolkit_to_sd_scripts(state_dict) + elif args.src == "sd-scripts" and args.dst == "ai-toolkit": + state_dict = convert_sd_scripts_to_ai_toolkit(state_dict) + + # eliminate 'shared tensors' + for k in list(state_dict.keys()): + state_dict[k] = state_dict[k].detach().clone() + else: + raise NotImplementedError(f"Conversion from {args.src} to {args.dst} is not supported") + + # save destination safetensors + logger.info(f"Saving destination file {args.dst_path}") + save_file(state_dict, args.dst_path, metadata=metadata) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert LoRA format") + parser.add_argument("--src", type=str, default="ai-toolkit", help="source format, ai-toolkit or sd-scripts") + parser.add_argument("--dst", type=str, default="sd-scripts", help="destination format, ai-toolkit or sd-scripts") + parser.add_argument("--src_path", type=str, default=None, help="source path") + parser.add_argument("--dst_path", type=str, default=None, help="destination path") + args = parser.parse_args() + main(args) diff --git a/networks/dylora.py b/networks/dylora.py index 637f3345..b0925453 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -18,10 +18,13 @@ from transformers import CLIPTextModel import torch from torch import nn from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + class DyLoRAModule(torch.nn.Module): """ replaces forward method of the original Linear, instead of replacing the original Linear module. @@ -195,7 +198,7 @@ def create_network( conv_alpha = 1.0 else: conv_alpha = float(conv_alpha) - + if unit is not None: unit = int(unit) else: @@ -211,6 +214,16 @@ def create_network( unit=unit, varbose=True, ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + return network @@ -280,6 +293,10 @@ class DyLoRANetwork(torch.nn.Module): self.alpha = alpha self.apply_to_conv = apply_to_conv + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + if modules_dim is not None: logger.info("create LoRA network from weights") else: @@ -320,9 +337,9 @@ class DyLoRANetwork(torch.nn.Module): lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit) loras.append(lora) return loras - + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] - + self.text_encoder_loras = [] for i, text_encoder in enumerate(text_encoders): if len(text_encoders) > 1: @@ -331,7 +348,7 @@ class DyLoRANetwork(torch.nn.Module): else: index = None logger.info("create LoRA for Text Encoder") - + text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) @@ -346,6 +363,14 @@ class DyLoRANetwork(torch.nn.Module): self.unet_loras = create_modules(True, unet, target_modules) logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + def set_multiplier(self, multiplier): self.multiplier = multiplier for lora in self.text_encoder_loras + self.unet_loras: @@ -406,27 +431,53 @@ class DyLoRANetwork(torch.nn.Module): logger.info(f"weights are merged") """ + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): self.requires_grad_(True) all_params = [] - def enumerate_params(loras): - params = [] + def assemble_params(loras, lr, ratio): + param_groups = {"lora": {}, "plus": {}} for lora in loras: - params.extend(lora.parameters()) + for name, param in lora.named_parameters(): + if ratio is not None and "lora_B" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + continue + + params.append(param_data) + return params if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} - if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr - all_params.append(param_data) + params = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) if self.unet_loras: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) + params = assemble_params( + self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio + ) + all_params.extend(params) return all_params diff --git a/networks/flux_extract_lora.py b/networks/flux_extract_lora.py new file mode 100644 index 00000000..63ab2960 --- /dev/null +++ b/networks/flux_extract_lora.py @@ -0,0 +1,219 @@ +# extract approximating LoRA by svd from two FLUX models +# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo! + +import argparse +import json +import os +import time +import torch +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from tqdm import tqdm +from library import flux_utils, sai_model_spec, model_util, sdxl_model_util +import lora +from library.utils import MemoryEfficientSafeOpen +from library.utils import setup_logging +from networks import lora_flux + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +# CLAMP_QUANTILE = 0.99 +# MIN_DIFF = 1e-1 + + +def save_to_file(file_name, state_dict, metadata, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + save_file(state_dict, file_name, metadata=metadata) + + +def svd( + model_org=None, + model_tuned=None, + save_to=None, + dim=4, + device=None, + save_precision=None, + clamp_quantile=0.99, + min_diff=0.01, + no_metadata=False, + mem_eff_safe_open=False, +): + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + calc_dtype = torch.float + save_dtype = str_to_dtype(save_precision) + store_device = "cpu" + + # open models + lora_weights = {} + if not mem_eff_safe_open: + # use original safetensors.safe_open + open_fn = lambda fn: safe_open(fn, framework="pt") + else: + logger.info("Using memory efficient safe_open") + open_fn = lambda fn: MemoryEfficientSafeOpen(fn) + + with open_fn(model_org) as f_org: + # filter keys + keys = [] + for key in f_org.keys(): + if not ("single_block" in key or "double_block" in key): + continue + if ".bias" in key: + continue + if "norm" in key: + continue + keys.append(key) + + with open_fn(model_tuned) as f_tuned: + for key in tqdm(keys): + # get tensors and calculate difference + value_o = f_org.get_tensor(key) + value_t = f_tuned.get_tensor(key) + mat = value_t.to(calc_dtype) - value_o.to(calc_dtype) + del value_o, value_t + + # extract LoRA weights + if device: + mat = mat.to(device) + out_dim, in_dim = mat.size()[0:2] + rank = min(dim, in_dim, out_dim) # LoRA rank cannot exceed the original dim + + mat = mat.squeeze() + + U, S, Vh = torch.linalg.svd(mat) + + U = U[:, :rank] + S = S[:rank] + U = U @ torch.diag(S) + + Vh = Vh[:rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, clamp_quantile) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + + U = U.to(store_device, dtype=save_dtype).contiguous() + Vh = Vh.to(store_device, dtype=save_dtype).contiguous() + + # print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}") + lora_weights[key] = (U, Vh) + del mat, U, S, Vh + + # make state dict for LoRA + lora_sd = {} + for key, (up_weight, down_weight) in lora_weights.items(): + lora_name = key.replace(".weight", "").replace(".", "_") + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + lora_name + lora_sd[lora_name + ".lora_up.weight"] = up_weight + lora_sd[lora_name + ".lora_down.weight"] = down_weight + lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0]) # same as rank + + # minimum metadata + net_kwargs = {} + metadata = { + "ss_v2": str(False), + "ss_base_model_version": flux_utils.MODEL_VERSION_FLUX_V1, + "ss_network_module": "networks.lora_flux", + "ss_network_dim": str(dim), + "ss_network_alpha": str(float(dim)), + "ss_network_args": json.dumps(net_kwargs), + } + + if not no_metadata: + title = os.path.splitext(os.path.basename(save_to))[0] + sai_metadata = sai_model_spec.build_metadata(lora_sd, False, False, False, True, False, time.time(), title, flux="dev") + metadata.update(sai_metadata) + + save_to_file(save_to, lora_sd, metadata, save_dtype) + + logger.info(f"LoRA weights saved to {save_to}") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat", + ) + parser.add_argument( + "--model_org", + type=str, + default=None, + required=True, + help="Original model: safetensors file / 元モデル、safetensors", + ) + parser.add_argument( + "--model_tuned", + type=str, + default=None, + required=True, + help="Tuned model, LoRA is difference of `original to tuned`: safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors", + ) + parser.add_argument( + "--mem_eff_safe_open", + action="store_true", + help="use memory efficient safe_open. This is an experimental feature, use only when memory is not enough." + " / メモリ効率の良いsafe_openを使用する。実装は実験的なものなので、メモリが足りない場合のみ使用してください。", + ) + parser.add_argument( + "--save_to", + type=str, + default=None, + required=True, + help="destination file name: safetensors file / 保存先のファイル名、safetensors", + ) + parser.add_argument( + "--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)" + ) + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) + parser.add_argument( + "--clamp_quantile", + type=float, + default=0.99, + help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99", + ) + # parser.add_argument( + # "--min_diff", + # type=float, + # default=0.01, + # help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /" + # + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01", + # ) + parser.add_argument( + "--no_metadata", + action="store_true", + help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + svd(**vars(args)) diff --git a/networks/flux_merge_lora.py b/networks/flux_merge_lora.py new file mode 100644 index 00000000..5e100a3b --- /dev/null +++ b/networks/flux_merge_lora.py @@ -0,0 +1,765 @@ +import argparse +import math +import os +import time +from typing import Any, Dict, Union + +import torch +from safetensors import safe_open +from safetensors.torch import load_file, save_file +from tqdm import tqdm + +from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import lora_flux as lora_flux +from library import sai_model_spec, train_util + + +def load_state_dict(file_name, dtype): + if os.path.splitext(file_name)[1] == ".safetensors": + sd = load_file(file_name) + metadata = train_util.load_metadata_from_safetensors(file_name) + else: + sd = torch.load(file_name, map_location="cpu") + metadata = {} + + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + + return sd, metadata + + +def save_to_file(file_name, state_dict: Dict[str, Union[Any, torch.Tensor]], dtype, metadata, mem_eff_save=False): + if dtype is not None: + logger.info(f"converting to {dtype}...") + for key in tqdm(list(state_dict.keys())): + if type(state_dict[key]) == torch.Tensor and state_dict[key].dtype.is_floating_point: + state_dict[key] = state_dict[key].to(dtype) + + logger.info(f"saving to: {file_name}") + if mem_eff_save: + mem_eff_save_file(state_dict, file_name, metadata=metadata) + else: + save_file(state_dict, file_name, metadata=metadata) + + +def merge_to_flux_model( + loading_device, + working_device, + flux_path: str, + clip_l_path: str, + t5xxl_path: str, + models, + ratios, + merge_dtype, + save_dtype, + mem_eff_load_save=False, +): + # create module map without loading state_dict + lora_name_to_module_key = {} + if flux_path is not None: + logger.info(f"loading keys from FLUX.1 model: {flux_path}") + with safe_open(flux_path, framework="pt", device=loading_device) as flux_file: + keys = list(flux_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_") + lora_name_to_module_key[lora_name] = key + + lora_name_to_clip_l_key = {} + if clip_l_path is not None: + logger.info(f"loading keys from clip_l model: {clip_l_path}") + with safe_open(clip_l_path, framework="pt", device=loading_device) as clip_l_file: + keys = list(clip_l_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP + "_" + module_name.replace(".", "_") + lora_name_to_clip_l_key[lora_name] = key + + lora_name_to_t5xxl_key = {} + if t5xxl_path is not None: + logger.info(f"loading keys from t5xxl model: {t5xxl_path}") + with safe_open(t5xxl_path, framework="pt", device=loading_device) as t5xxl_file: + keys = list(t5xxl_file.keys()) + for key in keys: + if key.endswith(".weight"): + module_name = ".".join(key.split(".")[:-1]) + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5 + "_" + module_name.replace(".", "_") + lora_name_to_t5xxl_key[lora_name] = key + + flux_state_dict = {} + clip_l_state_dict = {} + t5xxl_state_dict = {} + if mem_eff_load_save: + if flux_path is not None: + with MemoryEfficientSafeOpen(flux_path) as flux_file: + for key in tqdm(flux_file.keys()): + flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + + if clip_l_path is not None: + with MemoryEfficientSafeOpen(clip_l_path) as clip_l_file: + for key in tqdm(clip_l_file.keys()): + clip_l_state_dict[key] = clip_l_file.get_tensor(key).to(loading_device) + + if t5xxl_path is not None: + with MemoryEfficientSafeOpen(t5xxl_path) as t5xxl_file: + for key in tqdm(t5xxl_file.keys()): + t5xxl_state_dict[key] = t5xxl_file.get_tensor(key).to(loading_device) + else: + if flux_path is not None: + flux_state_dict = load_file(flux_path, device=loading_device) + if clip_l_path is not None: + clip_l_state_dict = load_file(clip_l_path, device=loading_device) + if t5xxl_path is not None: + t5xxl_state_dict = load_file(t5xxl_path, device=loading_device) + + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU + + logger.info(f"merging...") + for key in tqdm(list(lora_sd.keys())): + if "lora_down" in key: + lora_name = key[: key.rfind(".lora_down")] + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + if lora_name in lora_name_to_module_key: + module_weight_key = lora_name_to_module_key[lora_name] + state_dict = flux_state_dict + elif lora_name in lora_name_to_clip_l_key: + module_weight_key = lora_name_to_clip_l_key[lora_name] + state_dict = clip_l_state_dict + elif lora_name in lora_name_to_t5xxl_key: + module_weight_key = lora_name_to_t5xxl_key[lora_name] + state_dict = t5xxl_state_dict + else: + logger.warning( + f"no module found for LoRA weight: {key}. Skipping..." + f"LoRAの重みに対応するモジュールが見つかりませんでした。スキップします。" + ) + continue + + down_weight = lora_sd.pop(key) + up_weight = lora_sd.pop(up_key) + + dim = down_weight.size()[0] + alpha = lora_sd.pop(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + weight = state_dict[module_weight_key] + + weight = weight.to(working_device, merge_dtype) + up_weight = up_weight.to(working_device, merge_dtype) + down_weight = down_weight.to(working_device, merge_dtype) + + # logger.info(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + state_dict[module_weight_key] = weight.to(loading_device, save_dtype) + del up_weight + del down_weight + del weight + + if len(lora_sd) > 0: + logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}") + + return flux_state_dict, clip_l_state_dict, t5xxl_state_dict + + +def merge_to_flux_model_diffusers( + loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False +): + logger.info(f"loading keys from FLUX.1 model: {flux_model}") + if mem_eff_load_save: + flux_state_dict = {} + with MemoryEfficientSafeOpen(flux_model) as flux_file: + for key in tqdm(flux_file.keys()): + flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed + else: + flux_state_dict = load_file(flux_model, device=loading_device) + + def create_key_map(n_double_layers, n_single_layers): + key_map = {} + for index in range(n_double_layers): + prefix_from = f"transformer_blocks.{index}" + prefix_to = f"double_blocks.{index}" + + for end in ("weight", "bias"): + k = f"{prefix_from}.attn." + qkv_img = f"{prefix_to}.img_attn.qkv.{end}" + qkv_txt = f"{prefix_to}.txt_attn.qkv.{end}" + + key_map[f"{k}to_q.{end}"] = qkv_img + key_map[f"{k}to_k.{end}"] = qkv_img + key_map[f"{k}to_v.{end}"] = qkv_img + key_map[f"{k}add_q_proj.{end}"] = qkv_txt + key_map[f"{k}add_k_proj.{end}"] = qkv_txt + key_map[f"{k}add_v_proj.{end}"] = qkv_txt + + block_map = { + "attn.to_out.0.weight": "img_attn.proj.weight", + "attn.to_out.0.bias": "img_attn.proj.bias", + "norm1.linear.weight": "img_mod.lin.weight", + "norm1.linear.bias": "img_mod.lin.bias", + "norm1_context.linear.weight": "txt_mod.lin.weight", + "norm1_context.linear.bias": "txt_mod.lin.bias", + "attn.to_add_out.weight": "txt_attn.proj.weight", + "attn.to_add_out.bias": "txt_attn.proj.bias", + "ff.net.0.proj.weight": "img_mlp.0.weight", + "ff.net.0.proj.bias": "img_mlp.0.bias", + "ff.net.2.weight": "img_mlp.2.weight", + "ff.net.2.bias": "img_mlp.2.bias", + "ff_context.net.0.proj.weight": "txt_mlp.0.weight", + "ff_context.net.0.proj.bias": "txt_mlp.0.bias", + "ff_context.net.2.weight": "txt_mlp.2.weight", + "ff_context.net.2.bias": "txt_mlp.2.bias", + "attn.norm_q.weight": "img_attn.norm.query_norm.scale", + "attn.norm_k.weight": "img_attn.norm.key_norm.scale", + "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", + "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale", + } + + for k, v in block_map.items(): + key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}" + + for index in range(n_single_layers): + prefix_from = f"single_transformer_blocks.{index}" + prefix_to = f"single_blocks.{index}" + + for end in ("weight", "bias"): + k = f"{prefix_from}.attn." + qkv = f"{prefix_to}.linear1.{end}" + key_map[f"{k}to_q.{end}"] = qkv + key_map[f"{k}to_k.{end}"] = qkv + key_map[f"{k}to_v.{end}"] = qkv + key_map[f"{prefix_from}.proj_mlp.{end}"] = qkv + + block_map = { + "norm.linear.weight": "modulation.lin.weight", + "norm.linear.bias": "modulation.lin.bias", + "proj_out.weight": "linear2.weight", + "proj_out.bias": "linear2.bias", + "attn.norm_q.weight": "norm.query_norm.scale", + "attn.norm_k.weight": "norm.key_norm.scale", + } + + for k, v in block_map.items(): + key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}" + + # add as-is keys + values = list([(v if isinstance(v, str) else v[0]) for v in set(key_map.values())]) + values.sort() + key_map.update({v: v for v in values}) + + return key_map + + key_map = create_key_map(18, 38) # 18 double layers, 38 single layers + + def find_matching_key(flux_dict, lora_key): + lora_key = lora_key.replace("diffusion_model.", "") + lora_key = lora_key.replace("transformer.", "") + lora_key = lora_key.replace("lora_A", "lora_down").replace("lora_B", "lora_up") + lora_key = lora_key.replace("single_transformer_blocks", "single_blocks") + lora_key = lora_key.replace("transformer_blocks", "double_blocks") + + double_block_map = { + "attn.to_out.0": "img_attn.proj", + "norm1.linear": "img_mod.lin", + "norm1_context.linear": "txt_mod.lin", + "attn.to_add_out": "txt_attn.proj", + "ff.net.0.proj": "img_mlp.0", + "ff.net.2": "img_mlp.2", + "ff_context.net.0.proj": "txt_mlp.0", + "ff_context.net.2": "txt_mlp.2", + "attn.norm_q": "img_attn.norm.query_norm", + "attn.norm_k": "img_attn.norm.key_norm", + "attn.norm_added_q": "txt_attn.norm.query_norm", + "attn.norm_added_k": "txt_attn.norm.key_norm", + "attn.to_q": "img_attn.qkv", + "attn.to_k": "img_attn.qkv", + "attn.to_v": "img_attn.qkv", + "attn.add_q_proj": "txt_attn.qkv", + "attn.add_k_proj": "txt_attn.qkv", + "attn.add_v_proj": "txt_attn.qkv", + } + single_block_map = { + "norm.linear": "modulation.lin", + "proj_out": "linear2", + "attn.norm_q": "norm.query_norm", + "attn.norm_k": "norm.key_norm", + "attn.to_q": "linear1", + "attn.to_k": "linear1", + "attn.to_v": "linear1", + "proj_mlp": "linear1", + } + + # same key exists in both single_block_map and double_block_map, so we must care about single/double + # print("lora_key before double_block_map", lora_key) + for old, new in double_block_map.items(): + if "double" in lora_key: + lora_key = lora_key.replace(old, new) + # print("lora_key before single_block_map", lora_key) + for old, new in single_block_map.items(): + if "single" in lora_key: + lora_key = lora_key.replace(old, new) + # print("lora_key after mapping", lora_key) + + if lora_key in key_map: + flux_key = key_map[lora_key] + logger.info(f"Found matching key: {flux_key}") + return flux_key + + # If not found in key_map, try partial matching + potential_key = lora_key + ".weight" + logger.info(f"Searching for key: {potential_key}") + matches = [k for k in flux_dict.keys() if potential_key in k] + if matches: + logger.info(f"Found matching key: {matches[0]}") + return matches[0] + return None + + merged_keys = set() + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, _ = load_state_dict(model, merge_dtype) + + logger.info("merging...") + for key in lora_sd.keys(): + if "lora_down" in key or "lora_A" in key: + lora_name = key[: key.rfind(".lora_down" if "lora_down" in key else ".lora_A")] + up_key = key.replace("lora_down", "lora_up").replace("lora_A", "lora_B") + alpha_key = key[: key.index("lora_down" if "lora_down" in key else "lora_A")] + "alpha" + + logger.info(f"Processing LoRA key: {lora_name}") + flux_key = find_matching_key(flux_state_dict, lora_name) + + if flux_key is None: + logger.warning(f"no module found for LoRA weight: {key}") + continue + + logger.info(f"Merging LoRA key {lora_name} into Flux key {flux_key}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + weight = flux_state_dict[flux_key] + + weight = weight.to(working_device, merge_dtype) + up_weight = up_weight.to(working_device, merge_dtype) + down_weight = down_weight.to(working_device, merge_dtype) + + # print(up_weight.size(), down_weight.size(), weight.size()) + + if lora_name.startswith("transformer."): + if "qkv" in flux_key or "linear1" in flux_key: # combined qkv or qkv+mlp + update = ratio * (up_weight @ down_weight) * scale + # print(update.shape) + + if "img_attn" in flux_key or "txt_attn" in flux_key: + q, k, v = torch.chunk(weight, 3, dim=0) + if "to_q" in lora_name or "add_q_proj" in lora_name: + q += update.reshape(q.shape) + elif "to_k" in lora_name or "add_k_proj" in lora_name: + k += update.reshape(k.shape) + elif "to_v" in lora_name or "add_v_proj" in lora_name: + v += update.reshape(v.shape) + weight = torch.cat([q, k, v], dim=0) + elif "linear1" in flux_key: + q, k, v = torch.chunk(weight[: int(update.shape[-1] * 3)], 3, dim=0) + mlp = weight[int(update.shape[-1] * 3) :] + # print(q.shape, k.shape, v.shape, mlp.shape) + if "to_q" in lora_name: + q += update.reshape(q.shape) + elif "to_k" in lora_name: + k += update.reshape(k.shape) + elif "to_v" in lora_name: + v += update.reshape(v.shape) + elif "proj_mlp" in lora_name: + mlp += update.reshape(mlp.shape) + weight = torch.cat([q, k, v, mlp], dim=0) + else: + if len(weight.size()) == 2: + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = weight + ratio * conved * scale + else: + if len(weight.size()) == 2: + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = weight + ratio * conved * scale + + flux_state_dict[flux_key] = weight.to(loading_device, save_dtype) + merged_keys.add(flux_key) + del up_weight + del down_weight + del weight + + logger.info(f"Merged keys: {sorted(list(merged_keys))}") + return flux_state_dict + + +def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): + base_alphas = {} # alpha for merged model + base_dims = {} + + merged_sd = {} + base_model = None + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, lora_metadata = load_state_dict(model, merge_dtype) + + if lora_metadata is not None: + if base_model is None: + base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + + # get alpha and dim + alphas = {} # alpha for current model + dims = {} # dims for current model + for key in lora_sd.keys(): + if "alpha" in key: + lora_module_name = key[: key.rfind(".alpha")] + alpha = float(lora_sd[key].detach().numpy()) + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + elif "lora_down" in key: + lora_module_name = key[: key.rfind(".lora_down")] + dim = lora_sd[key].size()[0] + dims[lora_module_name] = dim + if lora_module_name not in base_dims: + base_dims[lora_module_name] = dim + + for lora_module_name in dims.keys(): + if lora_module_name not in alphas: + alpha = dims[lora_module_name] + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + + # merge + logger.info("merging...") + for key in tqdm(lora_sd.keys()): + if "alpha" in key: + continue + + if "lora_up" in key and concat: + concat_dim = 1 + elif "lora_down" in key and concat: + concat_dim = 0 + else: + concat_dim = None + + lora_module_name = key[: key.rfind(".lora_")] + + base_alpha = base_alphas[lora_module_name] + alpha = alphas[lora_module_name] + + scale = math.sqrt(alpha / base_alpha) * ratio + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + + if key in merged_sd: + assert ( + merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None + ), "weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。" + if concat_dim is not None: + merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) + else: + merged_sd[key] = merged_sd[key] + lora_sd[key] * scale + else: + merged_sd[key] = lora_sd[key] * scale + + # set alpha to sd + for lora_module_name, alpha in base_alphas.items(): + key = lora_module_name + ".alpha" + merged_sd[key] = torch.tensor(alpha) + if shuffle: + key_down = lora_module_name + ".lora_down.weight" + key_up = lora_module_name + ".lora_up.weight" + dim = merged_sd[key_down].shape[0] + perm = torch.randperm(dim) + merged_sd[key_down] = merged_sd[key_down][perm] + merged_sd[key_up] = merged_sd[key_up][:, perm] + + logger.info("merged model") + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + + # check all dims are same + dims_list = list(set(base_dims.values())) + alphas_list = list(set(base_alphas.values())) + all_same_dims = True + all_same_alphas = True + for dims in dims_list: + if dims != dims_list[0]: + all_same_dims = False + break + for alphas in alphas_list: + if alphas != alphas_list[0]: + all_same_alphas = False + break + + # build minimum metadata + dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" + alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" + metadata = train_util.build_minimum_network_metadata(str(False), base_model, "networks.lora", dims, alphas, None) + + return merged_sd, metadata + + +def merge(args): + if args.models is None: + args.models = [] + if args.ratios is None: + args.ratios = [] + + assert len(args.models) == len( + args.ratios + ), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + assert ( + args.save_to or args.clip_l_save_to or args.t5xxl_save_to + ), "save_to or clip_l_save_to or t5xxl_save_to must be specified / save_toまたはclip_l_save_toまたはt5xxl_save_toを指定してください" + dest_dir = os.path.dirname(args.save_to or args.clip_l_save_to or args.t5xxl_save_to) + if not os.path.exists(dest_dir): + logger.info(f"creating directory: {dest_dir}") + os.makedirs(dest_dir) + + if args.flux_model is not None or args.clip_l is not None or args.t5xxl is not None: + if not args.diffusers: + assert (args.clip_l is None and args.clip_l_save_to is None) or ( + args.clip_l is not None and args.clip_l_save_to is not None + ), "clip_l_save_to must be specified if clip_l is specified / clip_lが指定されている場合はclip_l_save_toも指定してください" + assert (args.t5xxl is None and args.t5xxl_save_to is None) or ( + args.t5xxl is not None and args.t5xxl_save_to is not None + ), "t5xxl_save_to must be specified if t5xxl is specified / t5xxlが指定されている場合はt5xxl_save_toも指定してください" + flux_state_dict, clip_l_state_dict, t5xxl_state_dict = merge_to_flux_model( + args.loading_device, + args.working_device, + args.flux_model, + args.clip_l, + args.t5xxl, + args.models, + args.ratios, + merge_dtype, + save_dtype, + args.mem_eff_load_save, + ) + else: + assert ( + args.clip_l is None and args.t5xxl is None + ), "clip_l and t5xxl are not supported with --diffusers / clip_l、t5xxlはDiffusersではサポートされていません" + flux_state_dict = merge_to_flux_model_diffusers( + args.loading_device, + args.working_device, + args.flux_model, + args.models, + args.ratios, + merge_dtype, + save_dtype, + args.mem_eff_load_save, + ) + clip_l_state_dict = None + t5xxl_state_dict = None + + if args.no_metadata or (flux_state_dict is None or len(flux_state_dict) == 0): + sai_metadata = None + else: + merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev" + ) + + if flux_state_dict is not None and len(flux_state_dict) > 0: + logger.info(f"saving FLUX model to: {args.save_to}") + save_to_file(args.save_to, flux_state_dict, save_dtype, sai_metadata, args.mem_eff_load_save) + + if clip_l_state_dict is not None and len(clip_l_state_dict) > 0: + logger.info(f"saving clip_l model to: {args.clip_l_save_to}") + save_to_file(args.clip_l_save_to, clip_l_state_dict, save_dtype, None, args.mem_eff_load_save) + + if t5xxl_state_dict is not None and len(t5xxl_state_dict) > 0: + logger.info(f"saving t5xxl model to: {args.t5xxl_save_to}") + save_to_file(args.t5xxl_save_to, t5xxl_state_dict, save_dtype, None, args.mem_eff_load_save) + + else: + flux_state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + + logger.info("calculating hashes and creating metadata...") + + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(flux_state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + if not args.no_metadata: + merged_from = sai_model_spec.build_merged_from(args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + flux_state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev" + ) + metadata.update(sai_metadata) + + logger.info(f"saving model to: {args.save_to}") + save_to_file(args.save_to, flux_state_dict, save_dtype, metadata) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_precision", + type=str, + default=None, + help="precision in saving, same to merging if omitted. supported types: " + "float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz" + " / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", + ) + parser.add_argument( + "--precision", + type=str, + default="float", + help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", + ) + parser.add_argument( + "--flux_model", + type=str, + default=None, + help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする", + ) + parser.add_argument( + "--clip_l", + type=str, + default=None, + help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)", + ) + parser.add_argument( + "--t5xxl", + type=str, + default=None, + help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)", + ) + parser.add_argument( + "--mem_eff_load_save", + action="store_true", + help="use custom memory efficient load and save functions for FLUX.1 model" + " / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する", + ) + parser.add_argument( + "--loading_device", + type=str, + default="cpu", + help="device to load FLUX.1 model. LoRA models are loaded on CPU / FLUX.1モデルを読み込むデバイス。LoRAモデルはCPUで読み込まれます", + ) + parser.add_argument( + "--working_device", + type=str, + default="cpu", + help="device to work (merge). Merging LoRA models are done on CPU." + + " / 作業(マージ)するデバイス。LoRAモデルのマージはCPUで行われます。", + ) + parser.add_argument( + "--save_to", + type=str, + default=None, + help="destination file name: safetensors file / 保存先のファイル名、safetensorsファイル", + ) + parser.add_argument( + "--clip_l_save_to", + type=str, + default=None, + help="destination file name for clip_l: safetensors file / clip_lの保存先のファイル名、safetensorsファイル", + ) + parser.add_argument( + "--t5xxl_save_to", + type=str, + default=None, + help="destination file name for t5xxl: safetensors file / t5xxlの保存先のファイル名、safetensorsファイル", + ) + parser.add_argument( + "--models", + type=str, + nargs="*", + help="LoRA models to merge: safetensors file / マージするLoRAモデル、safetensorsファイル", + ) + parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument( + "--no_metadata", + action="store_true", + help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", + ) + parser.add_argument( + "--concat", + action="store_true", + help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / " + + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)", + ) + parser.add_argument( + "--shuffle", + action="store_true", + help="shuffle lora weight./ " + "LoRAの重みをシャッフルする", + ) + parser.add_argument( + "--diffusers", + action="store_true", + help="merge Diffusers (?) LoRA models / Diffusers (?) LoRAモデルをマージする", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + merge(args) diff --git a/networks/lora.py b/networks/lora.py index d1208040..6f33f1a1 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -12,6 +12,7 @@ import numpy as np import torch import re from library.utils import setup_logging +from library.sdxl_original_unet import SdxlUNet2DConditionModel setup_logging() import logging @@ -385,14 +386,14 @@ class LoRAInfModule(LoRAModule): return out -def parse_block_lr_kwargs(nw_kwargs): +def parse_block_lr_kwargs(is_sdxl: bool, nw_kwargs: Dict) -> Optional[List[float]]: down_lr_weight = nw_kwargs.get("down_lr_weight", None) mid_lr_weight = nw_kwargs.get("mid_lr_weight", None) up_lr_weight = nw_kwargs.get("up_lr_weight", None) # 以上のいずれにも設定がない場合は無効としてNoneを返す if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None: - return None, None, None + return None # extract learning rate weight for each block if down_lr_weight is not None: @@ -401,18 +402,16 @@ def parse_block_lr_kwargs(nw_kwargs): down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] if mid_lr_weight is not None: - mid_lr_weight = float(mid_lr_weight) + mid_lr_weight = [(float(s) if s else 0.0) for s in mid_lr_weight.split(",")] if up_lr_weight is not None: if "," in up_lr_weight: up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] - down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight( - down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0)) + return get_block_lr_weight( + is_sdxl, down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0)) ) - return down_lr_weight, mid_lr_weight, up_lr_weight - def create_network( multiplier: float, @@ -424,6 +423,9 @@ def create_network( neuron_dropout: Optional[float] = None, **kwargs, ): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel) + if network_dim is None: network_dim = 4 # default if network_alpha is None: @@ -441,21 +443,21 @@ def create_network( # block dim/alpha/lr block_dims = kwargs.get("block_dims", None) - down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs) + block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs) # 以上のいずれかに指定があればblockごとのdim(rank)を有効にする - if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None: + if block_dims is not None or block_lr_weight is not None: block_alphas = kwargs.get("block_alphas", None) conv_block_dims = kwargs.get("conv_block_dims", None) conv_block_alphas = kwargs.get("conv_block_alphas", None) block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas( - block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha + is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha ) # remove block dim/alpha without learning rate block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas( - block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight + is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight ) else: @@ -488,10 +490,20 @@ def create_network( conv_block_dims=conv_block_dims, conv_block_alphas=conv_block_alphas, varbose=True, + is_sdxl=is_sdxl, ) - if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: - network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + if block_lr_weight is not None: + network.set_block_lr_weight(block_lr_weight) return network @@ -501,9 +513,13 @@ def create_network( # block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている # conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている def get_block_dims_and_alphas( - block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha + is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha ): - num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1 + if not is_sdxl: + num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS + else: + # 1+9+3+9+1=23, no LoRA for emb_layers (0) + num_total_blocks = 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1 def parse_ints(s): return [int(i) for i in s.split(",")] @@ -514,9 +530,10 @@ def get_block_dims_and_alphas( # block_dimsとblock_alphasをパースする。必ず値が入る if block_dims is not None: block_dims = parse_ints(block_dims) - assert ( - len(block_dims) == num_total_blocks - ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" + assert len(block_dims) == num_total_blocks, ( + f"block_dims must have {num_total_blocks} elements but {len(block_dims)} elements are given" + + f" / block_dimsは{num_total_blocks}個指定してください(指定された個数: {len(block_dims)})" + ) else: logger.warning( f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります" @@ -567,15 +584,25 @@ def get_block_dims_and_alphas( return block_dims, block_alphas, conv_block_dims, conv_block_alphas -# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出される可能性を考慮しておく +# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出せるようにclass外に出しておく +# 戻り値は block ごとの倍率のリスト def get_block_lr_weight( - down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold -) -> Tuple[List[float], List[float], List[float]]: + is_sdxl, + down_lr_weight: Union[str, List[float]], + mid_lr_weight: List[float], + up_lr_weight: Union[str, List[float]], + zero_threshold: float, +) -> Optional[List[float]]: # パラメータ未指定時は何もせず、今までと同じ動作とする if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None: - return None, None, None + return None - max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数 + if not is_sdxl: + max_len_for_down_or_up = LoRANetwork.NUM_OF_BLOCKS + max_len_for_mid = LoRANetwork.NUM_OF_MID_BLOCKS + else: + max_len_for_down_or_up = LoRANetwork.SDXL_NUM_OF_BLOCKS + max_len_for_mid = LoRANetwork.SDXL_NUM_OF_MID_BLOCKS def get_list(name_with_suffix) -> List[float]: import math @@ -585,15 +612,18 @@ def get_block_lr_weight( base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0 if name == "cosine": - return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))] + return [ + math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr + for i in reversed(range(max_len_for_down_or_up)) + ] elif name == "sine": - return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)] + return [math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr for i in range(max_len_for_down_or_up)] elif name == "linear": - return [i / (max_len - 1) + base_lr for i in range(max_len)] + return [i / (max_len_for_down_or_up - 1) + base_lr for i in range(max_len_for_down_or_up)] elif name == "reverse_linear": - return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))] + return [i / (max_len_for_down_or_up - 1) + base_lr for i in reversed(range(max_len_for_down_or_up))] elif name == "zeros": - return [0.0 + base_lr] * max_len + return [0.0 + base_lr] * max_len_for_down_or_up else: logger.error( "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" @@ -606,20 +636,36 @@ def get_block_lr_weight( if type(up_lr_weight) == str: up_lr_weight = get_list(up_lr_weight) - if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): - logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) - logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) - up_lr_weight = up_lr_weight[:max_len] - down_lr_weight = down_lr_weight[:max_len] + if (up_lr_weight != None and len(up_lr_weight) > max_len_for_down_or_up) or ( + down_lr_weight != None and len(down_lr_weight) > max_len_for_down_or_up + ): + logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len_for_down_or_up) + logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_down_or_up) + up_lr_weight = up_lr_weight[:max_len_for_down_or_up] + down_lr_weight = down_lr_weight[:max_len_for_down_or_up] - if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): - logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) - logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) + if mid_lr_weight != None and len(mid_lr_weight) > max_len_for_mid: + logger.warning("mid_weight is too long. Parameters after %d-th are ignored." % max_len_for_mid) + logger.warning("mid_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_mid) + mid_lr_weight = mid_lr_weight[:max_len_for_mid] - if down_lr_weight != None and len(down_lr_weight) < max_len: - down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) - if up_lr_weight != None and len(up_lr_weight) < max_len: - up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) + if (up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up) or ( + down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up + ): + logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_down_or_up) + logger.warning( + "down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_down_or_up + ) + + if down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up: + down_lr_weight = down_lr_weight + [1.0] * (max_len_for_down_or_up - len(down_lr_weight)) + if up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up: + up_lr_weight = up_lr_weight + [1.0] * (max_len_for_down_or_up - len(up_lr_weight)) + + if mid_lr_weight != None and len(mid_lr_weight) < max_len_for_mid: + logger.warning("mid_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_mid) + logger.warning("mid_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_mid) + mid_lr_weight = mid_lr_weight + [1.0] * (max_len_for_mid - len(mid_lr_weight)) if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): logger.info("apply block learning rate / 階層別学習率を適用します。") @@ -627,78 +673,139 @@ def get_block_lr_weight( down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}") else: + down_lr_weight = [1.0] * max_len_for_down_or_up logger.info("down_lr_weight: all 1.0, すべて1.0") if mid_lr_weight != None: - mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 + mid_lr_weight = [w if w > zero_threshold else 0 for w in mid_lr_weight] logger.info(f"mid_lr_weight: {mid_lr_weight}") else: - logger.info("mid_lr_weight: 1.0") + mid_lr_weight = [1.0] * max_len_for_mid + logger.info("mid_lr_weight: all 1.0, すべて1.0") if up_lr_weight != None: up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}") else: + up_lr_weight = [1.0] * max_len_for_down_or_up logger.info("up_lr_weight: all 1.0, すべて1.0") - return down_lr_weight, mid_lr_weight, up_lr_weight + lr_weight = down_lr_weight + mid_lr_weight + up_lr_weight + + if is_sdxl: + lr_weight = [1.0] + lr_weight + [1.0] # add 1.0 for emb_layers and out + + assert (not is_sdxl and len(lr_weight) == LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS) or ( + is_sdxl and len(lr_weight) == 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1 + ), f"lr_weight length is invalid: {len(lr_weight)}" + + return lr_weight # lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく def remove_block_dims_and_alphas( - block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight + is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight: Optional[List[float]] ): - # set 0 to block dim without learning rate to remove the block - if down_lr_weight != None: - for i, lr in enumerate(down_lr_weight): + if block_lr_weight is not None: + for i, lr in enumerate(block_lr_weight): if lr == 0: block_dims[i] = 0 if conv_block_dims is not None: conv_block_dims[i] = 0 - if mid_lr_weight != None: - if mid_lr_weight == 0: - block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0 - if conv_block_dims is not None: - conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0 - if up_lr_weight != None: - for i, lr in enumerate(up_lr_weight): - if lr == 0: - block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0 - if conv_block_dims is not None: - conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0 - return block_dims, block_alphas, conv_block_dims, conv_block_alphas # 外部から呼び出す可能性を考慮しておく -def get_block_index(lora_name: str) -> int: +def get_block_index(lora_name: str, is_sdxl: bool = False) -> int: block_idx = -1 # invalid lora name + if not is_sdxl: + m = RE_UPDOWN.search(lora_name) + if m: + g = m.groups() + i = int(g[1]) + j = int(g[3]) + if g[2] == "resnets": + idx = 3 * i + j + elif g[2] == "attentions": + idx = 3 * i + j + elif g[2] == "upsamplers" or g[2] == "downsamplers": + idx = 3 * i + 2 - m = RE_UPDOWN.search(lora_name) - if m: - g = m.groups() - i = int(g[1]) - j = int(g[3]) - if g[2] == "resnets": - idx = 3 * i + j - elif g[2] == "attentions": - idx = 3 * i + j - elif g[2] == "upsamplers" or g[2] == "downsamplers": - idx = 3 * i + 2 - - if g[0] == "down": - block_idx = 1 + idx # 0に該当するLoRAは存在しない - elif g[0] == "up": - block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx - - elif "mid_block_" in lora_name: - block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 + if g[0] == "down": + block_idx = 1 + idx # 0に該当するLoRAは存在しない + elif g[0] == "up": + block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx + elif "mid_block_" in lora_name: + block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 + else: + # copy from sdxl_train + if lora_name.startswith("lora_unet_"): + name = lora_name[len("lora_unet_") :] + if name.startswith("time_embed_") or name.startswith("label_emb_"): # No LoRA + block_idx = 0 # 0 + elif name.startswith("input_blocks_"): # 1-9 + block_idx = 1 + int(name.split("_")[2]) + elif name.startswith("middle_block_"): # 10-12 + block_idx = 10 + int(name.split("_")[2]) + elif name.startswith("output_blocks_"): # 13-21 + block_idx = 13 + int(name.split("_")[2]) + elif name.startswith("out_"): # 22, out, no LoRA + block_idx = 22 return block_idx +def convert_diffusers_to_sai_if_needed(weights_sd): + # only supports U-Net LoRA modules + + found_up_down_blocks = False + for k in list(weights_sd.keys()): + if "down_blocks" in k: + found_up_down_blocks = True + break + if "up_blocks" in k: + found_up_down_blocks = True + break + if not found_up_down_blocks: + return + + from library.sdxl_model_util import make_unet_conversion_map + + unet_conversion_map = make_unet_conversion_map() + unet_conversion_map = {hf.replace(".", "_")[:-1]: sd.replace(".", "_")[:-1] for sd, hf in unet_conversion_map} + + # # add extra conversion + # unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv" + + logger.info(f"Converting LoRA keys from Diffusers to SAI") + lora_unet_prefix = "lora_unet_" + for k in list(weights_sd.keys()): + if not k.startswith(lora_unet_prefix): + continue + + unet_module_name = k[len(lora_unet_prefix) :].split(".")[0] + + # search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small + for hf_module_name, sd_module_name in unet_conversion_map.items(): + if hf_module_name in unet_module_name: + new_key = ( + lora_unet_prefix + + unet_module_name.replace(hf_module_name, sd_module_name) + + k[len(lora_unet_prefix) + len(unet_module_name) :] + ) + weights_sd[new_key] = weights_sd.pop(k) + found = True + break + + if not found: + logger.warning(f"Key {k} is not found in unet_conversion_map") + + # Create network from weights for inference, weights are not loaded here (because can be merged) def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel) + if weights_sd is None: if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file, safe_open @@ -707,6 +814,10 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh else: weights_sd = torch.load(file, map_location="cpu") + # if keys are Diffusers based, convert to SAI based + if is_sdxl: + convert_diffusers_to_sai_if_needed(weights_sd) + # get dim/alpha mapping modules_dim = {} modules_alpha = {} @@ -730,19 +841,28 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh module_class = LoRAInfModule if for_inference else LoRAModule network = LoRANetwork( - text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class + text_encoder, + unet, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + is_sdxl=is_sdxl, ) # block lr - down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs) - if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: - network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) + block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs) + if block_lr_weight is not None: + network.set_block_lr_weight(block_lr_weight) return network, weights_sd class LoRANetwork(torch.nn.Module): NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 + NUM_OF_MID_BLOCKS = 1 + SDXL_NUM_OF_BLOCKS = 9 # SDXLのモデルでのinput/outputの層の数 total=1(base) 9(input) + 3(mid) + 9(output) + 1(out) = 23 + SDXL_NUM_OF_MID_BLOCKS = 3 UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] @@ -774,6 +894,7 @@ class LoRANetwork(torch.nn.Module): modules_alpha: Optional[Dict[str, int]] = None, module_class: Type[object] = LoRAModule, varbose: Optional[bool] = False, + is_sdxl: Optional[bool] = False, ) -> None: """ LoRA network: すごく引数が多いが、パターンは以下の通り @@ -794,6 +915,10 @@ class LoRANetwork(torch.nn.Module): self.rank_dropout = rank_dropout self.module_dropout = module_dropout + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + if modules_dim is not None: logger.info(f"create LoRA network from weights") elif block_dims is not None: @@ -855,7 +980,7 @@ class LoRANetwork(torch.nn.Module): alpha = modules_alpha[lora_name] elif is_unet and block_dims is not None: # U-Netでblock_dims指定あり - block_idx = get_block_index(lora_name) + block_idx = get_block_index(lora_name, is_sdxl) if is_linear or is_conv2d_1x1: dim = block_dims[block_idx] alpha = block_alphas[block_idx] @@ -925,9 +1050,7 @@ class LoRANetwork(torch.nn.Module): for name in skipped: logger.info(f"\t{name}") - self.up_lr_weight: List[float] = None - self.down_lr_weight: List[float] = None - self.mid_lr_weight: float = None + self.block_lr_weight = None self.block_lr = False # assertion @@ -958,12 +1081,12 @@ class LoRANetwork(torch.nn.Module): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - logger.info("enable LoRA for text encoder") + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") else: self.text_encoder_loras = [] if apply_unet: - logger.info("enable LoRA for U-Net") + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") else: self.unet_loras = [] @@ -1004,81 +1127,117 @@ class LoRANetwork(torch.nn.Module): logger.info(f"weights are merged") # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない - def set_block_lr_weight( - self, - up_lr_weight: List[float] = None, - mid_lr_weight: float = None, - down_lr_weight: List[float] = None, - ): + def set_block_lr_weight(self, block_lr_weight: Optional[List[float]]): self.block_lr = True - self.down_lr_weight = down_lr_weight - self.mid_lr_weight = mid_lr_weight - self.up_lr_weight = up_lr_weight + self.block_lr_weight = block_lr_weight - def get_lr_weight(self, lora: LoRAModule) -> float: - lr_weight = 1.0 - block_idx = get_block_index(lora.lora_name) - if block_idx < 0: - return lr_weight + def get_lr_weight(self, block_idx: int) -> float: + if not self.block_lr or self.block_lr_weight is None: + return 1.0 + return self.block_lr_weight[block_idx] - if block_idx < LoRANetwork.NUM_OF_BLOCKS: - if self.down_lr_weight != None: - lr_weight = self.down_lr_weight[block_idx] - elif block_idx == LoRANetwork.NUM_OF_BLOCKS: - if self.mid_lr_weight != None: - lr_weight = self.mid_lr_weight - elif block_idx > LoRANetwork.NUM_OF_BLOCKS: - if self.up_lr_weight != None: - lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1] + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio - return lr_weight + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") # 二つのText Encoderに別々の学習率を設定できるようにするといいかも def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): - self.requires_grad_(True) - all_params = [] + # TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?) + # if ( + # self.loraplus_lr_ratio is not None + # or self.loraplus_text_encoder_lr_ratio is not None + # or self.loraplus_unet_lr_ratio is not None + # ): + # assert ( + # optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower() + # ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません" - def enumerate_params(loras): - params = [] + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, ratio): + param_groups = {"lora": {}, "plus": {}} for lora in loras: - params.extend(lora.parameters()) - return params + for name, param in lora.named_parameters(): + if ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} - if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr - all_params.append(param_data) + params, descriptions = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions]) if self.unet_loras: if self.block_lr: + is_sdxl = False + for lora in self.unet_loras: + if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name: + is_sdxl = True + break + # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 block_idx_to_lora = {} for lora in self.unet_loras: - idx = get_block_index(lora.lora_name) + idx = get_block_index(lora.lora_name, is_sdxl) if idx not in block_idx_to_lora: block_idx_to_lora[idx] = [] block_idx_to_lora[idx].append(lora) # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - param_data = {"params": enumerate_params(block_loras)} - - if unet_lr is not None: - param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0]) - elif default_lr is not None: - param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0]) - if ("lr" in param_data) and (param_data["lr"] == 0): - continue - all_params.append(param_data) + params, descriptions = assemble_params( + block_loras, + (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx), + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions]) else: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) - return all_params + return all_params, lr_descriptions def enable_gradient_checkpointing(self): # not supported diff --git a/networks/lora_flux.py b/networks/lora_flux.py new file mode 100644 index 00000000..91e9cd77 --- /dev/null +++ b/networks/lora_flux.py @@ -0,0 +1,1157 @@ +# temporary minimum implementation of LoRA +# FLUX doesn't have Conv2d, so we ignore it +# TODO commonize with the original implementation + +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +from transformers import CLIPTextModel +import numpy as np +import torch +import re +from library.utils import setup_logging +from library.sdxl_original_unet import SdxlUNet2DConditionModel + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + split_dims: Optional[List[int]] = None, + ): + """ + if alpha == 0 or None, alpha is rank (no scaling). + + split_dims is used to mimic the split qkv of FLUX as same as Diffusers + """ + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + self.split_dims = split_dims + + if split_dims is None: + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + else: + # conv2d not supported + assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim" + assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear" + # print(f"split_dims: {split_dims}") + self.lora_down = torch.nn.ModuleList( + [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] + ) + self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) + for lora_down in self.lora_down: + torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) + for lora_up in self.lora_up: + torch.nn.init.zeros_(lora_up.weight) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + if self.split_dims is None: + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded + lx * self.multiplier * scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + + # normal dropout + if self.dropout is not None and self.training: + lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs] + + # rank dropout + if self.rank_dropout is not None and self.training: + masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs] + for i in range(len(lxs)): + if len(lx.size()) == 3: + masks[i] = masks[i].unsqueeze(1) + elif len(lx.size()) == 4: + masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1) + lxs[i] = lxs[i] * masks[i] + + # scaling for rank dropout: treat as if the rank is changed + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + + return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale + + +class LoRAInfModule(LoRAModule): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) + + self.org_module_ref = [org_module] # 後から参照できるように + self.enabled = True + self.network: LoRANetwork = None + + def set_network(self, network): + self.network = network + + # freezeしてマージする + def merge_to(self, sd, dtype, device): + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"] + org_dtype = weight.dtype + org_device = weight.device + weight = weight.to(torch.float) # calc in float + + if dtype is None: + dtype = org_dtype + if device is None: + device = org_device + + if self.split_dims is None: + # get up/down weight + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + else: + # split_dims + total_dims = sum(self.split_dims) + for i in range(len(self.split_dims)): + # get up/down weight + down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim) + up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank) + + # pad up_weight -> (total_dims, rank) + padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float) + padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight + + # merge weight + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + # 復元できるマージのため、このモジュールのweightを返す + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + # get up/down weight from module + up_weight = self.lora_up.weight.to(torch.float) + down_weight = self.lora_down.weight.to(torch.float) + + # pre-calculated weight + if len(down_weight.size()) == 2: + # linear + weight = self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = self.multiplier * conved * self.scale + + return weight + + def set_region(self, region): + self.region = region + self.region_mask = None + + def default_forward(self, x): + # logger.info(f"default_forward {self.lora_name} {x.size()}") + if self.split_dims is None: + lx = self.lora_down(x) + lx = self.lora_up(lx) + return self.org_forward(x) + lx * self.multiplier * self.scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + return self.default_forward(x) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + ae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + flux, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv + img_attn_dim = kwargs.get("img_attn_dim", None) + txt_attn_dim = kwargs.get("txt_attn_dim", None) + img_mlp_dim = kwargs.get("img_mlp_dim", None) + txt_mlp_dim = kwargs.get("txt_mlp_dim", None) + img_mod_dim = kwargs.get("img_mod_dim", None) + txt_mod_dim = kwargs.get("txt_mod_dim", None) + single_dim = kwargs.get("single_dim", None) # SingleStreamBlock + single_mod_dim = kwargs.get("single_mod_dim", None) # SingleStreamBlock + if img_attn_dim is not None: + img_attn_dim = int(img_attn_dim) + if txt_attn_dim is not None: + txt_attn_dim = int(txt_attn_dim) + if img_mlp_dim is not None: + img_mlp_dim = int(img_mlp_dim) + if txt_mlp_dim is not None: + txt_mlp_dim = int(txt_mlp_dim) + if img_mod_dim is not None: + img_mod_dim = int(img_mod_dim) + if txt_mod_dim is not None: + txt_mod_dim = int(txt_mod_dim) + if single_dim is not None: + single_dim = int(single_dim) + if single_mod_dim is not None: + single_mod_dim = int(single_mod_dim) + type_dims = [img_attn_dim, txt_attn_dim, img_mlp_dim, txt_mlp_dim, img_mod_dim, txt_mod_dim, single_dim, single_mod_dim] + if all([d is None for d in type_dims]): + type_dims = None + + # in_dims [img, time, vector, guidance, txt] + in_dims = kwargs.get("in_dims", None) + if in_dims is not None: + in_dims = in_dims.strip() + if in_dims.startswith("[") and in_dims.endswith("]"): + in_dims = in_dims[1:-1] + in_dims = [int(d) for d in in_dims.split(",")] # is it better to use ast.literal_eval? + assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)" + + # double/single train blocks + def parse_block_selection(selection: str, total_blocks: int) -> List[bool]: + """ + Parse a block selection string and return a list of booleans. + + Args: + selection (str): A string specifying which blocks to select. + total_blocks (int): The total number of blocks available. + + Returns: + List[bool]: A list of booleans indicating which blocks are selected. + """ + if selection == "all": + return [True] * total_blocks + if selection == "none" or selection == "": + return [False] * total_blocks + + selected = [False] * total_blocks + ranges = selection.split(",") + + for r in ranges: + if "-" in r: + start, end = map(str.strip, r.split("-")) + start = int(start) + end = int(end) + assert 0 <= start < total_blocks, f"invalid start index: {start}" + assert 0 <= end < total_blocks, f"invalid end index: {end}" + assert start <= end, f"invalid range: {start}-{end}" + for i in range(start, end + 1): + selected[i] = True + else: + index = int(r) + assert 0 <= index < total_blocks, f"invalid index: {index}" + selected[index] = True + + return selected + + train_double_block_indices = kwargs.get("train_double_block_indices", None) + train_single_block_indices = kwargs.get("train_single_block_indices", None) + if train_double_block_indices is not None: + train_double_block_indices = parse_block_selection(train_double_block_indices, NUM_DOUBLE_BLOCKS) + if train_single_block_indices is not None: + train_single_block_indices = parse_block_selection(train_single_block_indices, NUM_SINGLE_BLOCKS) + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # single or double blocks + train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double" + if train_blocks is not None: + assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}" + + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + + # train T5XXL + train_t5xxl = kwargs.get("train_t5xxl", False) + if train_t5xxl is not None: + train_t5xxl = True if train_t5xxl == "True" else False + + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoders, + flux, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + train_blocks=train_blocks, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + type_dims=type_dims, + in_dims=in_dims, + train_double_block_indices=train_double_block_indices, + train_single_block_indices=train_single_block_indices, + verbose=verbose, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping, and train t5xxl + modules_dim = {} + modules_alpha = {} + train_t5xxl = None + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + if train_t5xxl is None or train_t5xxl is False: + train_t5xxl = "lora_te3" in lora_name + + if train_t5xxl is None: + train_t5xxl = False + + # # split qkv + # double_qkv_rank = None + # single_qkv_rank = None + # rank = None + # for lora_name, dim in modules_dim.items(): + # if "double" in lora_name and "qkv" in lora_name: + # double_qkv_rank = dim + # elif "single" in lora_name and "linear1" in lora_name: + # single_qkv_rank = dim + # elif rank is None: + # rank = dim + # if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None: + # break + # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or ( + # single_qkv_rank is not None and single_qkv_rank != rank + # ) + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoders, + flux, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + ) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"] + LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible + LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible + + def __init__( + self, + text_encoders: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + train_blocks: Optional[str] = None, + split_qkv: bool = False, + train_t5xxl: bool = False, + type_dims: Optional[List[int]] = None, + in_dims: Optional[List[int]] = None, + train_double_block_indices: Optional[List[bool]] = None, + train_single_block_indices: Optional[List[bool]] = None, + verbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.train_blocks = train_blocks if train_blocks is not None else "all" + self.split_qkv = split_qkv + self.train_t5xxl = train_t5xxl + + self.type_dims = type_dims + self.in_dims = in_dims + self.train_double_block_indices = train_double_block_indices + self.train_single_block_indices = train_single_block_indices + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + self.in_dims = [0] * 5 # create in_dims + # verbose = True + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) + if self.split_qkv: + logger.info(f"split qkv for LoRA") + if self.train_blocks is not None: + logger.info(f"train {self.train_blocks} blocks only") + if train_t5xxl: + logger.info(f"train T5XXL as well") + + # create module instances + def create_modules( + is_flux: bool, + text_encoder_idx: Optional[int], + root_module: torch.nn.Module, + target_replace_modules: List[str], + filter: Optional[str] = None, + default_dim: Optional[int] = None, + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_FLUX + if is_flux + else (self.LORA_PREFIX_TEXT_ENCODER_CLIP if text_encoder_idx == 0 else self.LORA_PREFIX_TEXT_ENCODER_T5) + ) + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: # dirty hack for all modules + module = root_module # search all modules + + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + (name + "." if name else "") + child_name + lora_name = lora_name.replace(".", "_") + + if filter is not None and not filter in lora_name: + continue + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + + if is_flux and type_dims is not None: + identifier = [ + ("img_attn",), + ("txt_attn",), + ("img_mlp",), + ("txt_mlp",), + ("img_mod",), + ("txt_mod",), + ("single_blocks", "linear"), + ("modulation",), + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d # may be 0 for skip + break + + if ( + is_flux + and dim + and ( + self.train_double_block_indices is not None + or self.train_single_block_indices is not None + ) + and ("double" in lora_name or "single" in lora_name) + ): + # "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..." + block_index = int(lora_name.split("_")[4]) # bit dirty + if ( + "double" in lora_name + and self.train_double_block_indices is not None + and not self.train_double_block_indices[block_index] + ): + dim = 0 + elif ( + "single" in lora_name + and self.train_single_block_indices is not None + and not self.train_single_block_indices[block_index] + ): + dim = 0 + + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + # qkv split + split_dims = None + if is_flux and split_qkv: + if "double" in lora_name and "qkv" in lora_name: + split_dims = [3072] * 3 + elif "single" in lora_name and "linear1" in lora_name: + split_dims = [3072] * 3 + [12288] + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + split_dims=split_dims, + ) + loras.append(lora) + + if target_replace_modules is None: + break # all modules are searched + return loras, skipped + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + index = i + if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False + break + + logger.info(f"create LoRA for Text Encoder {index+1}:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.") + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + + # create LoRA for U-Net + if self.train_blocks == "all": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE + elif self.train_blocks == "single": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE + elif self.train_blocks == "double": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules) + + # img, time, vector, guidance, txt + if self.in_dims: + for filter, in_dim in zip(["_img_in", "_time_in", "_vector_in", "_guidance_in", "_txt_in"], self.in_dims): + loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim) + self.unet_loras.extend(loras) + + logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") + + skipped = skipped_te + skipped_un + if verbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to split qkv + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # split qkv + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + continue + + weight = state_dict[key] + lora_name = key.split(".")[0] + if "lora_down" in key and "weight" in key: + # dense weight (rank*3, in_dim) + split_weight = torch.chunk(weight, len(split_dims), dim=0) + for i, split_w in enumerate(split_weight): + state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w + + del state_dict[key] + # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}") + elif "lora_up" in key and "weight" in key: + # sparse weight (out_dim=sum(split_dims), rank*3) + rank = weight.size(1) // len(split_dims) + i = 0 + for j in range(len(split_dims)): + state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dims[j], j * rank : (j + 1) * rank] + i += split_dims[j] + del state_dict[key] + + # # check is sparse + # i = 0 + # is_zero = True + # for j in range(len(split_dims)): + # for k in range(len(split_dims)): + # if j == k: + # continue + # is_zero = is_zero and torch.all(weight[i : i + split_dims[j], k * rank : (k + 1) * rank] == 0) + # i += split_dims[j] + # if not is_zero: + # logger.warning(f"weight is not sparse: {key}") + # else: + # logger.info(f"weight is sparse: {key}") + + # print( + # f"split {key}: {weight.shape} to {[state_dict[k].shape for k in [f'{lora_name}.lora_up.{j}.weight' for j in range(len(split_dims))]]}" + # ) + + # alpha is unchanged + + return super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination, prefix, keep_vars) + + # merge qkv + state_dict = super().state_dict(destination, prefix, keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + rank = up_weights[0].size(1) + up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(len(split_dims)): + up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j] + i += split_dims[j] + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + + def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, flux, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP) or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_FLUX): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of two elements + # if float, use the same value for both text encoders + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): + text_encoder_lr = [default_lr, default_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr)] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, loraplus_ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if loraplus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * loraplus_ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP)] + te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)] + if len(te1_loras) > 0: + logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te3_loras) > 0: + logger.info(f"Text Encoder 2 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[1]}") + params, descriptions = assemble_params(te3_loras, text_encoder_lr[1], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 2 " + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py new file mode 100644 index 00000000..ce6d1a16 --- /dev/null +++ b/networks/lora_sd3.py @@ -0,0 +1,839 @@ +# temporary minimum implementation of LoRA +# SD3 doesn't have Conv2d, so we ignore it +# TODO commonize with the original/SD3/FLUX implementation + +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from transformers import CLIPTextModelWithProjection, T5EncoderModel +import numpy as np +import torch +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from networks.lora_flux import LoRAModule, LoRAInfModule +from library import sd3_models + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: sd3_models.SDVAE, + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], + mmdit, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv + context_attn_dim = kwargs.get("context_attn_dim", None) + context_mlp_dim = kwargs.get("context_mlp_dim", None) + context_mod_dim = kwargs.get("context_mod_dim", None) + x_attn_dim = kwargs.get("x_attn_dim", None) + x_mlp_dim = kwargs.get("x_mlp_dim", None) + x_mod_dim = kwargs.get("x_mod_dim", None) + if context_attn_dim is not None: + context_attn_dim = int(context_attn_dim) + if context_mlp_dim is not None: + context_mlp_dim = int(context_mlp_dim) + if context_mod_dim is not None: + context_mod_dim = int(context_mod_dim) + if x_attn_dim is not None: + x_attn_dim = int(x_attn_dim) + if x_mlp_dim is not None: + x_mlp_dim = int(x_mlp_dim) + if x_mod_dim is not None: + x_mod_dim = int(x_mod_dim) + type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim] + if all([d is None for d in type_dims]): + type_dims = None + + # emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear] + emb_dims = kwargs.get("emb_dims", None) + if emb_dims is not None: + emb_dims = emb_dims.strip() + if emb_dims.startswith("[") and emb_dims.endswith("]"): + emb_dims = emb_dims[1:-1] + emb_dims = [int(d) for d in emb_dims.split(",")] # is it better to use ast.literal_eval? + assert len(emb_dims) == 6, f"invalid emb_dims: {emb_dims}, must be 6 dimensions (context, t, x, y, final_mod, final_linear)" + + # double/single train blocks + def parse_block_selection(selection: str, total_blocks: int) -> List[bool]: + """ + Parse a block selection string and return a list of booleans. + + Args: + selection (str): A string specifying which blocks to select. + total_blocks (int): The total number of blocks available. + + Returns: + List[bool]: A list of booleans indicating which blocks are selected. + """ + if selection == "all": + return [True] * total_blocks + if selection == "none" or selection == "": + return [False] * total_blocks + + selected = [False] * total_blocks + ranges = selection.split(",") + + for r in ranges: + if "-" in r: + start, end = map(str.strip, r.split("-")) + start = int(start) + end = int(end) + assert 0 <= start < total_blocks, f"invalid start index: {start}" + assert 0 <= end < total_blocks, f"invalid end index: {end}" + assert start <= end, f"invalid range: {start}-{end}" + for i in range(start, end + 1): + selected[i] = True + else: + index = int(r) + assert 0 <= index < total_blocks, f"invalid index: {index}" + selected[index] = True + + return selected + + train_block_indices = kwargs.get("train_block_indices", None) + if train_block_indices is not None: + train_block_indices = parse_block_selection(train_block_indices, 999) # 999 is a dummy number + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + + # train T5XXL + train_t5xxl = kwargs.get("train_t5xxl", False) + if train_t5xxl is not None: + train_t5xxl = True if train_t5xxl == "True" else False + + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoders, + mmdit, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + type_dims=type_dims, + emb_dims=emb_dims, + train_block_indices=train_block_indices, + verbose=verbose, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, mmdit, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping, and train t5xxl + modules_dim = {} + modules_alpha = {} + train_t5xxl = None + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + if train_t5xxl is None or train_t5xxl is False: + train_t5xxl = "lora_te3" in lora_name + + if train_t5xxl is None: + train_t5xxl = False + + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoders, + mmdit, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + ) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + SD3_TARGET_REPLACE_MODULE = ["SingleDiTBlock"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"] + LORA_PREFIX_SD3 = "lora_unet" # make ComfyUI compatible + LORA_PREFIX_TEXT_ENCODER_CLIP_L = "lora_te1" + LORA_PREFIX_TEXT_ENCODER_CLIP_G = "lora_te2" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible + + def __init__( + self, + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], + unet: sd3_models.MMDiT, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + split_qkv: bool = False, + train_t5xxl: bool = False, + type_dims: Optional[List[int]] = None, + emb_dims: Optional[List[int]] = None, + train_block_indices: Optional[List[bool]] = None, + verbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.split_qkv = split_qkv + self.train_t5xxl = train_t5xxl + + self.type_dims = type_dims + self.emb_dims = emb_dims + self.train_block_indices = train_block_indices + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + self.emb_dims = [0] * 6 # create emb_dims + # verbose = True + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) + + qkv_dim = 0 + if self.split_qkv: + logger.info(f"split qkv for LoRA") + qkv_dim = unet.joint_blocks[0].context_block.attn.qkv.weight.size(0) + if train_t5xxl: + logger.info(f"train T5XXL as well") + + # create module instances + def create_modules( + is_mmdit: bool, + text_encoder_idx: Optional[int], + root_module: torch.nn.Module, + target_replace_modules: List[str], + filter: Optional[str] = None, + default_dim: Optional[int] = None, + include_conv2d_if_filter: bool = False, + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_SD3 + if is_mmdit + else [self.LORA_PREFIX_TEXT_ENCODER_CLIP_L, self.LORA_PREFIX_TEXT_ENCODER_CLIP_G, self.LORA_PREFIX_TEXT_ENCODER_T5][ + text_encoder_idx + ] + ) + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: # dirty hack for all modules + module = root_module # search all modules + + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + (name + "." if name else "") + child_name + lora_name = lora_name.replace(".", "_") + + force_incl_conv2d = False + if filter is not None: + if not filter in lora_name: + continue + force_incl_conv2d = include_conv2d_if_filter + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + + if is_mmdit and type_dims is not None: + # type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim] + identifier = [ + ("context_block", "attn"), + ("context_block", "mlp"), + ("context_block", "adaLN_modulation"), + ("x_block", "attn"), + ("x_block", "mlp"), + ("x_block", "adaLN_modulation"), + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d # may be 0 for skip + break + + if is_mmdit and dim and self.train_block_indices is not None and "joint_blocks" in lora_name: + # "lora_unet_joint_blocks_0_x_block_attn_proj..." + block_index = int(lora_name.split("_")[4]) # bit dirty + if self.train_block_indices is not None and not self.train_block_indices[block_index]: + dim = 0 + + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + elif force_incl_conv2d: + # x_embedder + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + # qkv split + split_dims = None + if is_mmdit and split_qkv: + if "joint_blocks" in lora_name and "qkv" in lora_name: + split_dims = [qkv_dim // 3] * 3 + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + split_dims=split_dims, + ) + loras.append(lora) + + if target_replace_modules is None: + break # all modules are searched + return loras, skipped + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + index = i + if not train_t5xxl and index >= 2: # 0: CLIP-L, 1: CLIP-G, 2: T5XXL, so we skip T5XXL if train_t5xxl is False + break + + logger.info(f"create LoRA for Text Encoder {index+1}:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.") + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + + # create LoRA for U-Net + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.SD3_TARGET_REPLACE_MODULE) + + # emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear] + if self.emb_dims: + for filter, in_dim in zip( + [ + "context_embedder", + "_t_embedder", # don't use "t_embedder" because it's used in "context_embedder" + "x_embedder", + "y_embedder", + "final_layer_adaLN_modulation", + "final_layer_linear", + ], + self.emb_dims, + ): + # x_embedder is conv2d, so we need to include it + loras, _ = create_modules( + True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder" + ) + # if len(loras) > 0: + # logger.info(f"create LoRA for {filter}: {len(loras)} modules.") + self.unet_loras.extend(loras) + + logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") + + skipped = skipped_te + skipped_un + if verbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to split qkv + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # split qkv + for key in list(state_dict.keys()): + if not ("joint_blocks" in key and "qkv" in key): + continue + + weight = state_dict[key] + lora_name = key.split(".")[0] + if "lora_down" in key and "weight" in key: + # dense weight (rank*3, in_dim) + split_weight = torch.chunk(weight, 3, dim=0) + for i, split_w in enumerate(split_weight): + state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w + + del state_dict[key] + # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}") + elif "lora_up" in key and "weight" in key: + # sparse weight (out_dim=sum(split_dims), rank*3) + rank = weight.size(1) // 3 + i = 0 + split_dim = weight.shape[0] // 3 + for j in range(3): + state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dim, j * rank : (j + 1) * rank] + i += split_dim + del state_dict[key] + + # alpha is unchanged + + return super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination, prefix, keep_vars) + + # merge qkv + state_dict = super().state_dict(destination, prefix, keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if not ("joint_blocks" in key and "qkv" in key): + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(3)] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(3)] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + split_dim, rank = up_weights[0].size() + qkv_dim = split_dim * 3 + up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(3): + up_weight[i : i + split_dim, j * rank : (j + 1) * rank] = up_weights[j] + i += split_dim + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + + def apply_to(self, text_encoders, mmdit, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, mmdit, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if ( + key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_L) + or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_G) + or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5) + ): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_SD3): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of three elements + # if float, use the same value for all three + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): + text_encoder_lr = [default_lr, default_lr, default_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr), float(text_encoder_lr)] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0], text_encoder_lr[0]] + elif len(text_encoder_lr) == 2: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[1], text_encoder_lr[1]] + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, loraplus_ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if loraplus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * loraplus_ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te1_loras = [ + lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_L) + ] + te2_loras = [ + lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_G) + ] + te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)] + if len(te1_loras) > 0: + logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te2_loras) > 0: + logger.info(f"Text Encoder 2 (CLIP-G): {len(te2_loras)} modules, LR {text_encoder_lr[1]}") + params, descriptions = assemble_params(te2_loras, text_encoder_lr[1], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te3_loras) > 0: + logger.info(f"Text Encoder 3 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[2]}") + params, descriptions = assemble_params(te3_loras, text_encoder_lr[2], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 3 " + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/networks/oft.py b/networks/oft.py index 461a9869..0c3a5393 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -4,13 +4,17 @@ import math import os from typing import Dict, List, Optional, Tuple, Type, Union from diffusers import AutoencoderKL +import einops from transformers import CLIPTextModel import numpy as np import torch +import torch.nn.functional as F import re from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -45,11 +49,16 @@ class OFTModule(torch.nn.Module): if type(alpha) == torch.Tensor: alpha = alpha.detach().numpy() - self.constraint = alpha * out_dim + + # constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility + # original alpha is 1e-5, so we use 1e-2 or 1e-4 for alpha + self.constraint = alpha * out_dim + self.register_buffer("alpha", torch.tensor(alpha)) self.block_size = out_dim // self.num_blocks self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size)) + self.I = torch.eye(self.block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) # cpu self.out_dim = out_dim self.shape = org_module.weight.shape @@ -69,27 +78,36 @@ class OFTModule(torch.nn.Module): norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) - block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I - R = torch.block_diag(*block_R_weighted) - - return R + if self.I.device != block_Q.device: + self.I = self.I.to(block_Q.device) + I = self.I + block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) + block_R_weighted = self.multiplier * (block_R - I) + I + return block_R_weighted def forward(self, x, scale=None): - x = self.org_forward(x) if self.multiplier == 0.0: - return x + return self.org_forward(x) + org_module = self.org_module[0] + org_dtype = x.dtype - R = self.get_weight().to(x.device, dtype=x.dtype) - if x.dim() == 4: - x = x.permute(0, 2, 3, 1) - x = torch.matmul(x, R) - x = x.permute(0, 3, 1, 2) - else: - x = torch.matmul(x, R) - return x + R = self.get_weight().to(torch.float32) + W = org_module.weight.to(torch.float32) + + if len(W.shape) == 4: # Conv2d + W_reshaped = einops.rearrange(W, "(k n) ... -> k n ...", k=self.num_blocks, n=self.block_size) + RW = torch.einsum("k n m, k n ... -> k m ...", R, W_reshaped) + RW = einops.rearrange(RW, "k m ... -> (k m) ...") + result = F.conv2d( + x, RW.to(org_dtype), org_module.bias, org_module.stride, org_module.padding, org_module.dilation, org_module.groups + ) + else: # Linear + W_reshaped = einops.rearrange(W, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size) + RW = torch.einsum("k n m, k n p -> k m p", R, W_reshaped) + RW = einops.rearrange(RW, "k m p -> (k m) p") + result = F.linear(x, RW.to(org_dtype), org_module.bias) + return result class OFTInfModule(OFTModule): @@ -115,18 +133,19 @@ class OFTInfModule(OFTModule): return self.org_forward(x) return super().forward(x, scale) - def merge_to(self, multiplier=None, sign=1): - R = self.get_weight(multiplier) * sign - + def merge_to(self, multiplier=None): # get org weight org_sd = self.org_module[0].state_dict() - org_weight = org_sd["weight"] - R = R.to(org_weight.device, dtype=org_weight.dtype) + org_weight = org_sd["weight"].to(torch.float32) - if org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", org_weight, R) - else: - weight = torch.einsum("oi, op -> pi", org_weight, R) + R = self.get_weight(multiplier).to(torch.float32) + + weight = org_weight.reshape(self.num_blocks, self.block_size, -1) + weight = torch.einsum("k n m, k n ... -> k m ...", R, weight) + weight = weight.reshape(org_weight.shape) + + # convert back to original dtype + weight = weight.to(org_sd["weight"].dtype) # set weight to org_module org_sd["weight"] = weight @@ -145,8 +164,16 @@ def create_network( ): if network_dim is None: network_dim = 4 # default - if network_alpha is None: - network_alpha = 1.0 + if network_alpha is None: # should be set + logger.info( + "network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します" + ) + network_alpha = 1e-3 + elif network_alpha >= 1: + logger.warning( + "network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3" + " / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨" + ) enable_all_linear = kwargs.get("enable_all_linear", None) enable_conv = kwargs.get("enable_conv", None) @@ -190,12 +217,11 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh else: if dim is None: dim = param.size()[0] - if has_conv2d is None and param.dim() == 4: + if has_conv2d is None and "in_layers_2" in name: has_conv2d = True - if all_linear is None: - if param.dim() == 3 and "attn" not in name: - all_linear = True - if dim is not None and alpha is not None and has_conv2d is not None: + if all_linear is None and "_ff_" in name: + all_linear = True + if dim is not None and alpha is not None and has_conv2d is not None and all_linear is not None: break if has_conv2d is None: has_conv2d = False @@ -241,7 +267,7 @@ class OFTNetwork(torch.nn.Module): self.alpha = alpha logger.info( - f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}" + f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}, enable_all_linear: {enable_all_linear}" ) # create module instances diff --git a/networks/oft_flux.py b/networks/oft_flux.py new file mode 100644 index 00000000..27b8b637 --- /dev/null +++ b/networks/oft_flux.py @@ -0,0 +1,482 @@ +# OFT network module + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +import einops +from transformers import CLIPTextModel +import numpy as np +import torch +import torch.nn.functional as F +import re +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class OFTModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + oft_name, + org_module: torch.nn.Module, + multiplier=1.0, + dim=4, + alpha=1, + split_dims: Optional[List[int]] = None, + ): + """ + dim -> num blocks + alpha -> constraint + + split_dims is used to mimic the split qkv of FLUX as same as Diffusers + """ + super().__init__() + self.oft_name = oft_name + self.num_blocks = dim + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().numpy() + self.register_buffer("alpha", torch.tensor(alpha)) + + # No conv2d in FLUX + # if "Linear" in org_module.__class__.__name__: + self.out_dim = org_module.out_features + # elif "Conv" in org_module.__class__.__name__: + # out_dim = org_module.out_channels + + if split_dims is None: + split_dims = [self.out_dim] + else: + assert sum(split_dims) == self.out_dim, "sum of split_dims must be equal to out_dim" + self.split_dims = split_dims + + # assert all dim is divisible by num_blocks + for split_dim in self.split_dims: + assert split_dim % self.num_blocks == 0, "split_dim must be divisible by num_blocks" + + self.constraint = [alpha * split_dim for split_dim in self.split_dims] + self.block_size = [split_dim // self.num_blocks for split_dim in self.split_dims] + self.oft_blocks = torch.nn.ParameterList( + [torch.nn.Parameter(torch.zeros(self.num_blocks, block_size, block_size)) for block_size in self.block_size] + ) + self.I = [torch.eye(block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) for block_size in self.block_size] + + self.shape = org_module.weight.shape + self.multiplier = multiplier + self.org_module = [org_module] # moduleにならないようにlistに入れる + + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + if self.I[0].device != self.oft_blocks[0].device: + self.I = [I.to(self.oft_blocks[0].device) for I in self.I] + + block_R_weighted_list = [] + for i in range(len(self.oft_blocks)): + block_Q = self.oft_blocks[i] - self.oft_blocks[i].transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint[i]) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + + I = self.I[i] + block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) + block_R_weighted = self.multiplier * (block_R - I) + I + + block_R_weighted_list.append(block_R_weighted) + + return block_R_weighted_list + + def forward(self, x, scale=None): + if self.multiplier == 0.0: + return self.org_forward(x) + + org_module = self.org_module[0] + org_dtype = x.dtype + + R = self.get_weight() + W = org_module.weight.to(torch.float32) + B = org_module.bias.to(torch.float32) + + # split W to match R + results = [] + d2 = 0 + for i in range(len(R)): + d1 = d2 + d2 += self.split_dims[i] + + W1 = W[d1:d2] + W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i]) + RW_1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped) + RW_1 = einops.rearrange(RW_1, "k m p -> (k m) p") + + B1 = B[d1:d2] + result = F.linear(x, RW_1.to(org_dtype), B1.to(org_dtype)) + results.append(result) + + result = torch.cat(results, dim=-1) + return result + + +class OFTInfModule(OFTModule): + def __init__( + self, + oft_name, + org_module: torch.nn.Module, + multiplier=1.0, + dim=4, + alpha=1, + split_dims: Optional[List[int]] = None, + **kwargs, + ): + # no dropout for inference + super().__init__(oft_name, org_module, multiplier, dim, alpha, split_dims) + self.enabled = True + self.network: OFTNetwork = None + + def set_network(self, network): + self.network = network + + def forward(self, x, scale=None): + if not self.enabled: + return self.org_forward(x) + return super().forward(x, scale) + + def merge_to(self, multiplier=None): + # get org weight + org_sd = self.org_module[0].state_dict() + W = org_sd["weight"].to(torch.float32) + R = self.get_weight(multiplier).to(torch.float32) + + d2 = 0 + W_list = [] + for i in range(len(self.oft_blocks)): + d1 = d2 + d2 += self.split_dims[i] + + W1 = W[d1:d2] + W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i]) + W1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped) + W1 = einops.rearrange(W1, "k m p -> (k m) p") + + W_list.append(W1) + + W = torch.cat(W_list, dim=-1) + + # convert back to original dtype + W = W.to(org_sd["weight"].dtype) + + # set weight to org_module + org_sd["weight"] = W + self.org_module[0].load_state_dict(org_sd) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: AutoencoderKL, + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], + unet, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: # should be set + logger.info( + "network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します" + ) + network_alpha = 1e-3 + elif network_alpha >= 1: + logger.warning( + "network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3" + " / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨" + ) + + # attn only or all linear (FFN) layers + enable_all_linear = kwargs.get("enable_all_linear", None) + # enable_conv = kwargs.get("enable_conv", None) + if enable_all_linear is not None: + enable_all_linear = bool(enable_all_linear) + # if enable_conv is not None: + # enable_conv = bool(enable_conv) + + network = OFTNetwork( + text_encoder, + unet, + multiplier=multiplier, + dim=network_dim, + alpha=network_alpha, + enable_all_linear=enable_all_linear, + varbose=True, + ) + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # check dim, alpha and if weights have for conv2d + dim = None + alpha = None + all_linear = None + for name, param in weights_sd.items(): + if name.endswith(".alpha"): + if alpha is None: + alpha = param.item() + elif "qkv" in name: + continue # ignore qkv + else: + if dim is None: + dim = param.size()[0] + if all_linear is None and "_mlp" in name: + all_linear = True + if dim is not None and alpha is not None and all_linear is not None: + break + if all_linear is None: + all_linear = False + + module_class = OFTInfModule if for_inference else OFTModule + network = OFTNetwork( + text_encoder, + unet, + multiplier=multiplier, + dim=dim, + alpha=alpha, + enable_all_linear=all_linear, + module_class=module_class, + ) + return network, weights_sd + + +class OFTNetwork(torch.nn.Module): + FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR = ["DoubleStreamBlock", "SingleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY = ["SelfAttention"] + OFT_PREFIX_UNET = "oft_unet" + + def __init__( + self, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + dim: int = 4, + alpha: float = 1, + enable_all_linear: Optional[bool] = False, + module_class: Union[Type[OFTModule], Type[OFTInfModule]] = OFTModule, + varbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.train_t5xxl = False # make compatible with LoRA + self.multiplier = multiplier + + self.dim = dim + self.alpha = alpha + + logger.info( + f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_all_linear: {enable_all_linear}" + ) + + # create module instances + def create_modules( + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[OFTModule]: + prefix = self.OFT_PREFIX_UNET + ofts = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = "Linear" in child_module.__class__.__name__ + + if is_linear: + oft_name = prefix + "." + name + "." + child_name + oft_name = oft_name.replace(".", "_") + # logger.info(oft_name) + + if "double" in oft_name and "qkv" in oft_name: + split_dims = [3072] * 3 + elif "single" in oft_name and "linear1" in oft_name: + split_dims = [3072] * 3 + [12288] + else: + split_dims = None + + oft = module_class(oft_name, child_module, self.multiplier, dim, alpha, split_dims) + ofts.append(oft) + return ofts + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + if enable_all_linear: + target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR + else: + target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY + + self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules) + logger.info(f"create OFT for Flux: {len(self.unet_ofts)} modules.") + + # assertion + names = set() + for oft in self.unet_ofts: + assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}" + names.add(oft.oft_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for oft in self.unet_ofts: + oft.multiplier = self.multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): + assert apply_unet, "apply_unet must be True" + + for oft in self.unet_ofts: + oft.apply_to() + self.add_module(oft.oft_name, oft) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): + logger.info("enable OFT for U-Net") + + for oft in self.unet_ofts: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(oft.oft_name): + sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key] + oft.load_state_dict(sd_for_lora, False) + oft.merge_to() + + logger.info(f"weights are merged") + + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + self.requires_grad_(True) + all_params = [] + + def enumerate_params(ofts): + params = [] + for oft in ofts: + params.extend(oft.parameters()) + + # logger.info num of params + num_params = 0 + for p in params: + num_params += p.numel() + logger.info(f"OFT params: {num_params}") + return params + + param_data = {"params": enumerate_params(self.unet_ofts)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) + + return all_params + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + oft.merge_to() + # sd = org_module.state_dict() + # org_weight = sd["weight"] + # lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype) + # sd["weight"] = org_weight + lora_weight + # assert sd["weight"].shape == org_weight.shape + # org_module.load_state_dict(sd) + + org_module._lora_restored = False + oft.enabled = False diff --git a/networks/resize_lora.py b/networks/resize_lora.py index d697baa4..7df7ef0c 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -39,12 +39,7 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - +def save_to_file(file_name, state_dict, metadata): if model_util.is_safetensors(file_name): save_file(state_dict, file_name, metadata) else: @@ -349,12 +344,18 @@ def resize(args): metadata["ss_network_dim"] = "Dynamic" metadata["ss_network_alpha"] = "Dynamic" + # cast to save_dtype before calculating hashes + for key in list(state_dict.keys()): + value = state_dict[key] + if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype: + state_dict[key] = value.to(save_dtype) + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, metadata) + save_to_file(args.save_to, state_dict, metadata) def setup_parser() -> argparse.ArgumentParser: diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index 3383a80d..b147eb44 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -1,18 +1,25 @@ +import itertools import math import argparse import os import time +import concurrent.futures import torch from safetensors.torch import load_file, save_file from tqdm import tqdm from library import sai_model_spec, sdxl_model_util, train_util import library.model_util as model_util import lora +import oft +from svd_merge_lora import format_lbws, get_lbw_block_index, LAYER26 from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": sd = load_file(file_name) @@ -28,36 +35,58 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, model, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - +def save_to_file(file_name, model, metadata): if os.path.splitext(file_name)[1] == ".safetensors": save_file(model, file_name, metadata=metadata) else: torch.save(model, file_name) -def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): - text_encoder1.to(merge_dtype) +def detect_method_from_training_model(models, dtype): + for model in models: + # TODO It is better to use key names to detect the method + lora_sd, _ = load_state_dict(model, dtype) + for key in tqdm(lora_sd.keys()): + if "lora_up" in key or "lora_down" in key: + return "LoRA" + elif "oft_blocks" in key: + return "OFT" + + +def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws, merge_dtype): text_encoder1.to(merge_dtype) + text_encoder2.to(merge_dtype) unet.to(merge_dtype) + # detect the method: OFT or LoRA_module + method = detect_method_from_training_model(models, merge_dtype) + logger.info(f"method:{method}") + + if lbws: + lbws, _, LBW_TARGET_IDX = format_lbws(lbws) + else: + LBW_TARGET_IDX = [] + # create module map name_to_module = {} for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): - if i <= 1: - if i == 0: - prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 + if method == "LoRA": + if i <= 1: + if i == 0: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 + else: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2 + target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE else: - prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2 - target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE - else: - prefix = lora.LoRANetwork.LORA_PREFIX_UNET + prefix = lora.LoRANetwork.LORA_PREFIX_UNET + target_replace_modules = ( + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + ) + elif method == "OFT": + prefix = oft.OFTNetwork.OFT_PREFIX_UNET + # ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY target_replace_modules = ( - lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 ) for name, module in root_module.named_modules(): @@ -68,65 +97,172 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ lora_name = lora_name.replace(".", "_") name_to_module[lora_name] = child_module - for model, ratio in zip(models, ratios): + for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) logger.info(f"merging...") - for key in tqdm(lora_sd.keys()): - if "lora_down" in key: - up_key = key.replace("lora_down", "lora_up") - alpha_key = key[: key.index("lora_down")] + "alpha" - # find original module for this lora - module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if lbw: + lbw_weights = [1] * 26 + for index, value in zip(LBW_TARGET_IDX, lbw): + lbw_weights[index] = value + logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}") + + if method == "LoRA": + for key in tqdm(lora_sd.keys()): + if "lora_down" in key: + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + # find original module for this lora + module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if module_name not in name_to_module: + logger.info(f"no module found for LoRA weight: {key}") + continue + module = name_to_module[module_name] + # logger.info(f"apply {key} to {module}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + if lbw: + index = get_lbw_block_index(key, True) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける + + # W <- W + U * D + weight = module.weight + # logger.info(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + module.weight = torch.nn.Parameter(weight) + + elif method == "OFT": + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + for key in tqdm(lora_sd.keys()): + if "oft_blocks" in key: + oft_blocks = lora_sd[key] + dim = oft_blocks.shape[0] + break + for key in tqdm(lora_sd.keys()): + if "alpha" in key: + oft_blocks = lora_sd[key] + alpha = oft_blocks.item() + break + + def merge_to(key): + if "alpha" in key: + return + + # find original module for this OFT + module_name = ".".join(key.split(".")[:-1]) if module_name not in name_to_module: - logger.info(f"no module found for LoRA weight: {key}") - continue + logger.info(f"no module found for OFT weight: {key}") + return module = name_to_module[module_name] + # logger.info(f"apply {key} to {module}") - down_weight = lora_sd[key] - up_weight = lora_sd[up_key] + oft_blocks = lora_sd[key] - dim = down_weight.size()[0] - alpha = lora_sd.get(alpha_key, dim) - scale = alpha / dim + if isinstance(module, torch.nn.Linear): + out_dim = module.out_features + elif isinstance(module, torch.nn.Conv2d): + out_dim = module.out_channels - # W <- W + U * D - weight = module.weight - # logger.info(module_name, down_weight.size(), up_weight.size()) - if len(weight.size()) == 2: - # linear - weight = weight + ratio * (up_weight @ down_weight) * scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - weight - + ratio - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * scale - ) + num_blocks = dim + block_size = out_dim // dim + constraint = (0 if alpha is None else alpha) * out_dim + + multiplier = 1 + if lbw: + index = get_lbw_block_index(key, False) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + multiplier *= lbw_weights[index] + + block_Q = oft_blocks - oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1) + block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + block_R_weighted = multiplier * block_R + (1 - multiplier) * I + R = torch.block_diag(*block_R_weighted) + + # get org weight + org_sd = module.state_dict() + org_weight = org_sd["weight"].to(device) + + R = R.to(org_weight.device, dtype=org_weight.dtype) + + if org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", org_weight, R) else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + ratio * conved * scale + weight = torch.einsum("oi, op -> pi", org_weight, R) + + weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor module.weight = torch.nn.Parameter(weight) + # TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough + max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys()))) -def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): + +def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=False): base_alphas = {} # alpha for merged model base_dims = {} + # detect the method: OFT or LoRA_module + method = detect_method_from_training_model(models, merge_dtype) + if method == "OFT": + raise ValueError( + "OFT model is not supported for merging OFT models. / OFTモデルはOFTモデル同士のマージには対応していません" + ) + + if lbws: + lbws, _, LBW_TARGET_IDX = format_lbws(lbws) + else: + LBW_TARGET_IDX = [] + merged_sd = {} v2 = None base_model = None - for model, ratio in zip(models, ratios): + for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) + if lbw: + lbw_weights = [1] * 26 + for index, value in zip(LBW_TARGET_IDX, lbw): + lbw_weights[index] = value + logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}") + if lora_metadata is not None: if v2 is None: v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず @@ -164,7 +300,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): for key in tqdm(lora_sd.keys()): if "alpha" in key: continue - + if "lora_up" in key and concat: concat_dim = 1 elif "lora_down" in key and concat: @@ -178,8 +314,14 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): alpha = alphas[lora_module_name] scale = math.sqrt(alpha / base_alpha) * ratio - scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 - + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + + if lbw: + index = get_lbw_block_index(key, True) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける + if key in merged_sd: assert ( merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None @@ -201,7 +343,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): dim = merged_sd[key_down].shape[0] perm = torch.randperm(dim) merged_sd[key_down] = merged_sd[key_down][perm] - merged_sd[key_up] = merged_sd[key_up][:,perm] + merged_sd[key_up] = merged_sd[key_up][:, perm] logger.info("merged model") logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") @@ -229,7 +371,15 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + assert len(args.models) == len( + args.ratios + ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + if args.lbws: + assert len(args.models) == len( + args.lbws + ), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" + else: + args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく def str_to_dtype(p): if p == "float": @@ -257,7 +407,7 @@ def merge(args): ckpt_info, ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu") - merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype) + merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, args.lbws, merge_dtype) if args.no_metadata: sai_metadata = None @@ -273,7 +423,13 @@ def merge(args): args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype ) else: - state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle) + + # cast to save_dtype before calculating hashes + for key in list(state_dict.keys()): + value = state_dict[key] + if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype: + state_dict[key] = value.to(save_dtype) logger.info(f"calculating hashes and creating metadata...") @@ -290,7 +446,7 @@ def merge(args): metadata.update(sai_metadata) logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) + save_to_file(args.save_to, state_dict, metadata) def setup_parser() -> argparse.ArgumentParser: @@ -316,12 +472,19 @@ def setup_parser() -> argparse.ArgumentParser: help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", ) parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + "--save_to", + type=str, + default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", ) parser.add_argument( - "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + "--models", + type=str, + nargs="*", + help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors", ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率") parser.add_argument( "--no_metadata", action="store_true", @@ -337,8 +500,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--shuffle", action="store_true", - help="shuffle lora weight./ " - + "LoRAの重みをシャッフルする", + help="shuffle lora weight./ " + "LoRAの重みをシャッフルする", ) return parser diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index cb00a600..c79b45ac 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -1,5 +1,8 @@ import argparse +import itertools +import json import os +import re import time import torch from safetensors.torch import load_file, save_file @@ -8,12 +11,195 @@ from library import sai_model_spec, train_util import library.model_util as model_util import lora from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) CLAMP_QUANTILE = 0.99 +ACCEPTABLE = [12, 17, 20, 26] +SDXL_LAYER_NUM = [12, 20] + +LAYER12 = { + "BASE": True, + "IN00": False, + "IN01": False, + "IN02": False, + "IN03": False, + "IN04": True, + "IN05": True, + "IN06": False, + "IN07": True, + "IN08": True, + "IN09": False, + "IN10": False, + "IN11": False, + "MID": True, + "OUT00": True, + "OUT01": True, + "OUT02": True, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": False, + "OUT07": False, + "OUT08": False, + "OUT09": False, + "OUT10": False, + "OUT11": False, +} + +LAYER17 = { + "BASE": True, + "IN00": False, + "IN01": True, + "IN02": True, + "IN03": False, + "IN04": True, + "IN05": True, + "IN06": False, + "IN07": True, + "IN08": True, + "IN09": False, + "IN10": False, + "IN11": False, + "MID": True, + "OUT00": False, + "OUT01": False, + "OUT02": False, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": True, + "OUT07": True, + "OUT08": True, + "OUT09": True, + "OUT10": True, + "OUT11": True, +} + +LAYER20 = { + "BASE": True, + "IN00": True, + "IN01": True, + "IN02": True, + "IN03": True, + "IN04": True, + "IN05": True, + "IN06": True, + "IN07": True, + "IN08": True, + "IN09": False, + "IN10": False, + "IN11": False, + "MID": True, + "OUT00": True, + "OUT01": True, + "OUT02": True, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": True, + "OUT07": True, + "OUT08": True, + "OUT09": False, + "OUT10": False, + "OUT11": False, +} + +LAYER26 = { + "BASE": True, + "IN00": True, + "IN01": True, + "IN02": True, + "IN03": True, + "IN04": True, + "IN05": True, + "IN06": True, + "IN07": True, + "IN08": True, + "IN09": True, + "IN10": True, + "IN11": True, + "MID": True, + "OUT00": True, + "OUT01": True, + "OUT02": True, + "OUT03": True, + "OUT04": True, + "OUT05": True, + "OUT06": True, + "OUT07": True, + "OUT08": True, + "OUT09": True, + "OUT10": True, + "OUT11": True, +} + +assert len([v for v in LAYER12.values() if v]) == 12 +assert len([v for v in LAYER17.values() if v]) == 17 +assert len([v for v in LAYER20.values() if v]) == 20 +assert len([v for v in LAYER26.values() if v]) == 26 + +RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") + + +def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int: + # lbw block index is 0-based, but 0 for text encoder, so we return 0 for text encoder + if "text_model_encoder_" in lora_name: # LoRA for text encoder + return 0 + + # lbw block index is 1-based for U-Net, and no "input_blocks.0" in CompVis SD, so "input_blocks.1" have index 2 + block_idx = -1 # invalid lora name + if not is_sdxl: + NUM_OF_BLOCKS = 12 # up/down blocks + m = RE_UPDOWN.search(lora_name) + if m: + g = m.groups() + up_down = g[0] + i = int(g[1]) + j = int(g[3]) + if up_down == "down": + if g[2] == "resnets" or g[2] == "attentions": + idx = 3 * i + j + 1 + elif g[2] == "downsamplers": + idx = 3 * (i + 1) + else: + return block_idx # invalid lora name + elif up_down == "up": + if g[2] == "resnets" or g[2] == "attentions": + idx = 3 * i + j + elif g[2] == "upsamplers": + idx = 3 * i + 2 + else: + return block_idx # invalid lora name + + if g[0] == "down": + block_idx = 1 + idx # 1-based index, down block index + elif g[0] == "up": + block_idx = 1 + NUM_OF_BLOCKS + 1 + idx # 1-based index, num blocks, mid block, up block index + + elif "mid_block_" in lora_name: + block_idx = 1 + NUM_OF_BLOCKS # 1-based index, num blocks, mid block + else: + # SDXL: some numbers are skipped + if lora_name.startswith("lora_unet_"): + name = lora_name[len("lora_unet_") :] + if name.startswith("time_embed_") or name.startswith("label_emb_"): # 1, No LoRA in sd-scripts + block_idx = 1 + elif name.startswith("input_blocks_"): # 1-8 to 2-9 + block_idx = 1 + int(name.split("_")[2]) + elif name.startswith("middle_block_"): # 13 + block_idx = 13 + elif name.startswith("output_blocks_"): # 0-8 to 14-22 + block_idx = 14 + int(name.split("_")[2]) + elif name.startswith("out_"): # 23, No LoRA in sd-scripts + block_idx = 23 + + return block_idx + def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -30,24 +216,53 @@ def load_state_dict(file_name, dtype): return sd, metadata -def save_to_file(file_name, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - +def save_to_file(file_name, state_dict, metadata): if os.path.splitext(file_name)[1] == ".safetensors": save_file(state_dict, file_name, metadata=metadata) else: torch.save(state_dict, file_name) -def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): +def format_lbws(lbws): + try: + # lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している + lbws = [json.loads(lbw) for lbw in lbws] + except Exception: + raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください") + assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください" + assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください" + assert all( + len(lbw) in ACCEPTABLE for lbw in lbws + ), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください" + assert all( + all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws + ), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください" + + layer_num = len(lbws[0]) + is_sdxl = True if layer_num in SDXL_LAYER_NUM else False + FLAGS = { + "12": LAYER12.values(), + "17": LAYER17.values(), + "20": LAYER20.values(), + "26": LAYER26.values(), + }[str(layer_num)] + LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag] + return lbws, is_sdxl, LBW_TARGET_IDX + + +def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype): logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") merged_sd = {} - v2 = None + v2 = None # This is meaning LoRA Metadata v2, Not meaning SD2 base_model = None - for model, ratio in zip(models, ratios): + + if lbws: + lbws, is_sdxl, LBW_TARGET_IDX = format_lbws(lbws) + else: + is_sdxl = False + LBW_TARGET_IDX = [] + + for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws): logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) @@ -57,6 +272,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty if base_model is None: base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + if lbw: + lbw_weights = [1] * 26 + for index, value in zip(LBW_TARGET_IDX, lbw): + lbw_weights[index] = value + logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}") + # merge logger.info(f"merging...") for key in tqdm(list(lora_sd.keys())): @@ -80,10 +301,10 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty # make original weight if not exist if lora_module_name not in merged_sd: weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype) - if device: - weight = weight.to(device) else: weight = merged_sd[lora_module_name] + if device: + weight = weight.to(device) # merge to weight if device: @@ -93,6 +314,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty # W <- W + U * D scale = alpha / network_dim + if lbw: + index = get_lbw_block_index(key, is_sdxl) + is_lbw_target = index in LBW_TARGET_IDX + if is_lbw_target: + scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける + if device: # and isinstance(scale, torch.Tensor): scale = scale.to(device) @@ -109,13 +336,16 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) weight = weight + ratio * conved * scale - merged_sd[lora_module_name] = weight + merged_sd[lora_module_name] = weight.to("cpu") # extract from merged weights logger.info("extract new lora...") merged_lora_sd = {} with torch.no_grad(): for lora_module_name, mat in tqdm(list(merged_sd.items())): + if device: + mat = mat.to(device) + conv2d = len(mat.size()) == 4 kernel_size = None if not conv2d else mat.size()[2:4] conv2d_3x3 = conv2d and kernel_size != (1, 1) @@ -154,7 +384,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous() merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous() - merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank) + merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank, device="cpu") # build minimum metadata dims = f"{new_rank}" @@ -169,7 +399,15 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + assert len(args.models) == len( + args.ratios + ), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + if args.lbws: + assert len(args.models) == len( + args.lbws + ), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください" + else: + args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく def str_to_dtype(p): if p == "float": @@ -187,9 +425,15 @@ def merge(args): new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank state_dict, metadata, v2, base_model = merge_lora_models( - args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype + args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype ) + # cast to save_dtype before calculating hashes + for key in list(state_dict.keys()): + value = state_dict[key] + if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype: + state_dict[key] = value.to(save_dtype) + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) @@ -211,7 +455,7 @@ def merge(args): metadata.update(sai_metadata) logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, metadata) + save_to_file(args.save_to, state_dict, metadata) def setup_parser() -> argparse.ArgumentParser: @@ -231,12 +475,19 @@ def setup_parser() -> argparse.ArgumentParser: help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", ) parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + "--save_to", + type=str, + default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", ) parser.add_argument( - "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + "--models", + type=str, + nargs="*", + help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors", ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率") parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") parser.add_argument( "--new_conv_rank", @@ -244,7 +495,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ", ) - parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) parser.add_argument( "--no_metadata", action="store_true", diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..484d3aef --- /dev/null +++ b/pytest.ini @@ -0,0 +1,8 @@ +[pytest] +minversion = 6.0 +testpaths = + tests +filterwarnings = + ignore::DeprecationWarning + ignore::UserWarning + ignore::FutureWarning diff --git a/requirements.txt b/requirements.txt index e99775b8..e0091749 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,22 +1,23 @@ -accelerate==0.25.0 -transformers==4.36.2 +accelerate==0.33.0 +transformers==4.44.0 diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 -opencv-python==4.7.0.68 +opencv-python==4.8.1.78 einops==0.7.0 pytorch-lightning==1.9.0 -bitsandbytes==0.43.0 +bitsandbytes==0.44.0 prodigyopt==1.0 lion-pytorch==0.0.6 +schedulefree==1.4 tensorboard -safetensors==0.4.2 +safetensors==0.4.4 # gradio==3.16.2 altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 -huggingface-hub==0.20.1 +huggingface-hub==0.24.5 # for Image utils imagesize==1.4.1 # for BLIP captioning @@ -38,5 +39,7 @@ imagesize==1.4.1 # open-clip-torch==2.20.0 # For logging rich==13.7.0 +# for T5XXL tokenizer (SD3/FLUX) +sentencepiece==0.2.0 # for kohya_ss library -e . diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py new file mode 100644 index 00000000..86dba246 --- /dev/null +++ b/sd3_minimal_inference.py @@ -0,0 +1,407 @@ +# Minimum Inference Code for SD3 + +import argparse +import datetime +import math +import os +import random +from typing import Optional, Tuple +import numpy as np + +import torch +from safetensors.torch import safe_open, load_file +import torch.amp +from tqdm import tqdm +from PIL import Image +from transformers import CLIPTextModelWithProjection, T5EncoderModel + +from library.device_utils import init_ipex, get_preferred_device +from networks import lora_sd3 + +init_ipex() + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from library import sd3_models, sd3_utils, strategy_sd3 +from library.utils import load_safetensors + + +def get_noise(seed, latent, device="cpu"): + # generator = torch.manual_seed(seed) + generator = torch.Generator(device) + generator.manual_seed(seed) + return torch.randn(latent.size(), dtype=latent.dtype, layout=latent.layout, generator=generator, device=device) + + +def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): + start = sampling.timestep(sampling.sigma_max) + end = sampling.timestep(sampling.sigma_min) + timesteps = torch.linspace(start, end, steps) + sigs = [] + for x in range(len(timesteps)): + ts = timesteps[x] + sigs.append(sampling.sigma(ts)) + sigs += [0.0] + return torch.FloatTensor(sigs) + + +def max_denoise(model_sampling, sigmas): + max_sigma = float(model_sampling.sigma_max) + sigma = float(sigmas[0]) + return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma + + +def do_sample( + height: int, + width: int, + initial_latent: Optional[torch.Tensor], + seed: int, + cond: Tuple[torch.Tensor, torch.Tensor], + neg_cond: Tuple[torch.Tensor, torch.Tensor], + mmdit: sd3_models.MMDiT, + steps: int, + cfg_scale: float, + dtype: torch.dtype, + device: str, +): + if initial_latent is None: + # latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 # this seems to be a bug in the original code. thanks to furusu for pointing it out + latent = torch.zeros(1, 16, height // 8, width // 8, device=device) + else: + latent = initial_latent + + latent = latent.to(dtype).to(device) + + noise = get_noise(seed, latent, device) + + model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 + + sigmas = get_sigmas(model_sampling, steps).to(device) + # sigmas = sigmas[int(steps * (1 - denoise)) :] # do not support i2i + + # conditioning = fix_cond(conditioning) + # neg_cond = fix_cond(neg_cond) + # extra_args = {"cond": cond, "uncond": neg_cond, "cond_scale": guidance_scale} + + noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas)) + + c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype) + y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype) + + x = noise_scaled.to(device).to(dtype) + # print(x.shape) + + with torch.no_grad(): + for i in tqdm(range(len(sigmas) - 1)): + sigma_hat = sigmas[i] + + timestep = model_sampling.timestep(sigma_hat).float() + timestep = torch.FloatTensor([timestep, timestep]).to(device) + + x_c_nc = torch.cat([x, x], dim=0) + # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + + with torch.autocast(device_type=device.type, dtype=dtype): + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + model_output = model_output.float() + batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) + + pos_out, neg_out = batched.chunk(2) + denoised = neg_out + (pos_out - neg_out) * cfg_scale + # print(denoised.shape) + + # d = to_d(x, sigma_hat, denoised) + dims_to_append = x.ndim - sigma_hat.ndim + sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] + # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) + """Converts a denoiser output to a Karras ODE derivative.""" + d = (x - denoised) / sigma_hat_dims + + dt = sigmas[i + 1] - sigma_hat + + # Euler method + x = x + d * dt + x = x.to(dtype) + + latent = x + latent = vae.process_out(latent) + return latent + + +def generate_image( + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, + clip_l: CLIPTextModelWithProjection, + clip_g: CLIPTextModelWithProjection, + t5xxl: T5EncoderModel, + steps: int, + prompt: str, + seed: int, + target_width: int, + target_height: int, + device: str, + negative_prompt: str, + cfg_scale: float, +): + # prepare embeddings + logger.info("Encoding prompts...") + + # TODO support one-by-one offloading + clip_l.to(device) + clip_g.to(device) + t5xxl.to(device) + + with torch.autocast(device_type=device.type, dtype=mmdit.dtype), torch.no_grad(): + tokens_and_masks = tokenize_strategy.tokenize(prompt) + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask + ) + cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + tokens_and_masks = tokenize_strategy.tokenize(negative_prompt) + lg_out, t5_out, pooled, neg_l_attn_mask, neg_g_attn_mask, neg_t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask + ) + neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # attn masks are not used currently + + if args.offload: + clip_l.to("cpu") + clip_g.to("cpu") + t5xxl.to("cpu") + + # generate image + logger.info("Generating image...") + mmdit.to(device) + latent_sampled = do_sample(target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, cfg_scale, sd3_dtype, device) + if args.offload: + mmdit.to("cpu") + + # latent to image + vae.to(device) + with torch.no_grad(): + image = vae.decode(latent_sampled) + + if args.offload: + vae.to("cpu") + + image = image.float() + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) + decoded_np = decoded_np.astype(np.uint8) + out_image = Image.fromarray(decoded_np) + + # save image + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") + out_image.save(output_path) + + logger.info(f"Saved image to {output_path}") + + +if __name__ == "__main__": + target_height = 1024 + target_width = 1024 + + # steps = 50 # 28 # 50 + # cfg_scale = 5 + # seed = 1 # None # 1 + + device = get_preferred_device() + + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--clip_g", type=str, required=False) + parser.add_argument("--clip_l", type=str, required=False) + parser.add_argument("--t5xxl", type=str, required=False) + parser.add_argument("--t5xxl_token_length", type=int, default=256, help="t5xxl token length, default: 256") + parser.add_argument("--apply_lg_attn_mask", action="store_true") + parser.add_argument("--apply_t5_attn_mask", action="store_true") + parser.add_argument("--prompt", type=str, default="A photo of a cat") + # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders + parser.add_argument("--negative_prompt", type=str, default="") + parser.add_argument("--cfg_scale", type=float, default=5.0) + parser.add_argument("--offload", action="store_true", help="Offload to CPU") + parser.add_argument("--output_dir", type=str, default=".") + # parser.add_argument("--do_not_use_t5xxl", action="store_true") + # parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch") + parser.add_argument("--fp16", action="store_true") + parser.add_argument("--bf16", action="store_true") + parser.add_argument("--seed", type=int, default=1) + parser.add_argument("--steps", type=int, default=50) + parser.add_argument( + "--lora_weights", + type=str, + nargs="*", + default=[], + help="LoRA weights, only supports networks.lora_sd3, each argument is a `path;multiplier` (semi-colon separated)", + ) + parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") + parser.add_argument("--width", type=int, default=target_width) + parser.add_argument("--height", type=int, default=target_height) + parser.add_argument("--interactive", action="store_true") + args = parser.parse_args() + + seed = args.seed + steps = args.steps + + sd3_dtype = torch.float32 + if args.fp16: + sd3_dtype = torch.float16 + elif args.bf16: + sd3_dtype = torch.bfloat16 + + loading_device = "cpu" if args.offload else device + + # load state dict + logger.info(f"Loading SD3 models from {args.ckpt_path}...") + # state_dict = load_file(args.ckpt_path) + state_dict = load_safetensors(args.ckpt_path, loading_device, disable_mmap=True, dtype=sd3_dtype) + + # load text encoders + clip_l = sd3_utils.load_clip_l(args.clip_l, sd3_dtype, loading_device, state_dict=state_dict) + clip_g = sd3_utils.load_clip_g(args.clip_g, sd3_dtype, loading_device, state_dict=state_dict) + t5xxl = sd3_utils.load_t5xxl(args.t5xxl, sd3_dtype, loading_device, state_dict=state_dict) + + # MMDiT and VAE + vae = sd3_utils.load_vae(None, sd3_dtype, loading_device, state_dict=state_dict) + mmdit = sd3_utils.load_mmdit(state_dict, sd3_dtype, loading_device) + + clip_l.to(sd3_dtype) + clip_g.to(sd3_dtype) + t5xxl.to(sd3_dtype) + vae.to(sd3_dtype) + mmdit.to(sd3_dtype) + if not args.offload: + # make sure to move to the device: some tensors are created in the constructor on the CPU + clip_l.to(device) + clip_g.to(device) + t5xxl.to(device) + vae.to(device) + mmdit.to(device) + + clip_l.eval() + clip_g.eval() + t5xxl.eval() + mmdit.eval() + vae.eval() + + # load tokenizers + logger.info("Loading tokenizers...") + tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length) + encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() + + # LoRA + lora_models: list[lora_sd3.LoRANetwork] = [] + for weights_file in args.lora_weights: + if ";" in weights_file: + weights_file, multiplier = weights_file.split(";") + multiplier = float(multiplier) + else: + multiplier = 1.0 + + weights_sd = load_file(weights_file) + module = lora_sd3 + lora_model, _ = module.create_network_from_weights(multiplier, None, vae, [clip_l, clip_g, t5xxl], mmdit, weights_sd, True) + + if args.merge_lora_weights: + lora_model.merge_to([clip_l, clip_g, t5xxl], mmdit, weights_sd) + else: + lora_model.apply_to([clip_l, clip_g, t5xxl], mmdit) + info = lora_model.load_state_dict(weights_sd, strict=True) + logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + lora_model.eval() + lora_model.to(device) + + lora_models.append(lora_model) + + if not args.interactive: + generate_image( + mmdit, + vae, + clip_l, + clip_g, + t5xxl, + args.steps, + args.prompt, + args.seed, + args.width, + args.height, + device, + args.negative_prompt, + args.cfg_scale, + ) + else: + # loop for interactive + width = args.width + height = args.height + steps = None + cfg_scale = args.cfg_scale + + while True: + print( + "Enter prompt (empty to exit). Options: --w --h --s --d " + " --n , `--n -` for empty negative prompt" + "Options are kept for the next prompt. Current options:" + f" width={width}, height={height}, steps={steps}, seed={seed}, cfg_scale={cfg_scale}" + ) + prompt = input() + if prompt == "": + break + + # parse options + options = prompt.split("--") + prompt = options[0].strip() + seed = None + negative_prompt = None + for opt in options[1:]: + try: + opt = opt.strip() + if opt.startswith("w"): + width = int(opt[1:].strip()) + elif opt.startswith("h"): + height = int(opt[1:].strip()) + elif opt.startswith("s"): + steps = int(opt[1:].strip()) + elif opt.startswith("d"): + seed = int(opt[1:].strip()) + elif opt.startswith("m"): + mutipliers = opt[1:].strip().split(",") + if len(mutipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(mutipliers[i])) + elif opt.startswith("n"): + negative_prompt = opt[1:].strip() + if negative_prompt == "-": + negative_prompt = "" + elif opt.startswith("c"): + cfg_scale = float(opt[1:].strip()) + except ValueError as e: + logger.error(f"Invalid option: {opt}, {e}") + + generate_image( + mmdit, + vae, + clip_l, + clip_g, + t5xxl, + steps if steps is not None else args.steps, + prompt, + seed if seed is not None else args.seed, + width, + height, + device, + negative_prompt if negative_prompt is not None else args.negative_prompt, + cfg_scale, + ) + + logger.info("Done!") diff --git a/sd3_train.py b/sd3_train.py new file mode 100644 index 00000000..120455e7 --- /dev/null +++ b/sd3_train.py @@ -0,0 +1,1073 @@ +# training with captions + +import argparse +from concurrent.futures import ThreadPoolExecutor +import copy +import math +import os +from multiprocessing import Value +from typing import List +import toml + +from tqdm import tqdm + +import torch +from library import utils +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from diffusers import DDPMScheduler +from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3 +from library.sdxl_train_util import match_mixed_precision + +# , sdxl_model_util + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util + +# import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments + +# from library.custom_train_functions import ( +# apply_snr_weight, +# prepare_scheduler_for_custom_training, +# scale_v_prediction_loss_like_noise_prediction, +# add_v_prediction_like_loss, +# apply_debiased_estimation, +# apply_masked_loss, +# ) + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + # assert ( + # not args.train_text_encoder or not args.cache_text_encoder_outputs + # ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), ( + "when training text encoder, text encoder outputs must not be cached (except for T5XXL)" + + " / text encoderの学習時はtext encoderの出力はキャッシュできません(t5xxlのみキャッシュすることは可能です)" + ) + + if args.use_t5xxl_cache_only and not args.cache_text_encoder_outputs: + logger.warning( + "use_t5xxl_cache_only is enabled, so cache_text_encoder_outputs is automatically enabled." + + " / use_t5xxl_cache_onlyが有効なため、cache_text_encoder_outputsも自動的に有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.train_t5xxl: + assert ( + args.train_text_encoder + ), "when training T5XXL, text encoder (CLIP-L/G) must be trained / T5XXLを学習するときはtext encoder (CLIP-L/G)も学習する必要があります" + assert ( + not args.cache_text_encoder_outputs + ), "when training T5XXL, t5xxl output must not be cached / T5XXLを学習するときはt5xxlの出力をキャッシュできません" + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(8) # TODO これでいいか確認 + + if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + False, + False, + False, + False, + ) + ) + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + + # t5xxl_dtype = weight_dtype + model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx) + if args.clip_l is None: + sd3_state_dict = utils.load_safetensors( + args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype + ) + else: + sd3_state_dict = None + + # load tokenizer and prepare tokenize strategy + sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy) + + # load clip_l, clip_g, t5xxl for caching text encoder outputs + # clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) + # clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load) + clip_l = sd3_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) + clip_g = sd3_utils.load_clip_g(args.clip_g, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) + t5xxl = sd3_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) + assert clip_l is not None and clip_g is not None and t5xxl is not None, "clip_l, clip_g, t5xxl must be specified" + + # prepare text encoding strategy + text_encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy( + args.apply_lg_attn_mask, args.apply_t5_attn_mask, args.clip_l_dropout_rate, args.clip_g_dropout_rate, args.t5_dropout_rate + ) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # 学習を準備する:モデルを適切な状態にする + train_clip = False + train_t5xxl = False + + if args.train_text_encoder: + accelerator.print("enable text encoder training") + if args.gradient_checkpointing: + clip_l.gradient_checkpointing_enable() + clip_g.gradient_checkpointing_enable() + if args.train_t5xxl: + t5xxl.gradient_checkpointing_enable() + + lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train + lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train + lr_t5xxl = args.learning_rate_te3 if args.learning_rate_te3 is not None else args.learning_rate # 0 means not train + train_clip = lr_te1 != 0 or lr_te2 != 0 + train_t5xxl = lr_t5xxl != 0 and args.train_t5xxl + + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + t5xxl.to(weight_dtype) + clip_l.requires_grad_(train_clip) + clip_g.requires_grad_(train_clip) + t5xxl.requires_grad_(train_t5xxl) + else: + print("disable text encoder training") + clip_l.to(weight_dtype) + clip_g.to(weight_dtype) + t5xxl.to(weight_dtype) + clip_l.requires_grad_(False) + clip_g.requires_grad_(False) + t5xxl.requires_grad_(False) + lr_te1 = 0 + lr_te2 = 0 + lr_t5xxl = 0 + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + clip_l.to(accelerator.device) + clip_g.to(accelerator.device) + t5xxl.to(accelerator.device) + clip_l.eval() + clip_g.eval() + t5xxl.eval() + + text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + train_clip or args.use_t5xxl_cache_only, # if clip is trained or t5xxl is cached, caching is partial + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = sd3_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, + [clip_l, clip_g, t5xxl], + tokens_and_masks, + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + enable_dropout=False, + ) + + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + if not args.use_t5xxl_cache_only: + clip_l = None + clip_g = None + t5xxl = None + + clean_memory_on_device(accelerator.device) + + # load VAE for caching latents + if sd3_state_dict is None: + logger.info(f"load state dict for MMDiT and VAE from {args.pretrained_model_name_or_path}") + sd3_state_dict = utils.load_safetensors( + args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype + ) + + vae = sd3_utils.load_vae(args.vae, weight_dtype, "cpu", args.disable_mmap_load_safetensors, state_dict=sd3_state_dict) + if cache_latents: + # vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load) + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + + train_dataset_group.new_cache_latents(vae, accelerator) + + vae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # load MMDIT + mmdit = sd3_utils.load_mmdit(sd3_state_dict, model_dtype, "cpu") + + # attn_mode = "xformers" if args.xformers else "torch" + # assert ( + # attn_mode == "torch" + # ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + + mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate) + + # set resolutions for positional embeddings + if args.enable_scaled_pos_embed: + resolutions = train_dataset_group.get_resolutions() + latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in resolutions] # 8 is stride for latent + latent_sizes = list(set(latent_sizes)) # remove duplicates + logger.info(f"Prepare scaled positional embeddings for resolutions: {resolutions}, sizes: {latent_sizes}") + mmdit.enable_scaled_pos_embed(True, latent_sizes) + + if args.gradient_checkpointing: + mmdit.enable_gradient_checkpointing() + + train_mmdit = args.learning_rate != 0 + mmdit.requires_grad_(train_mmdit) + if not train_mmdit: + mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdit will not be prepared + + # block swap + is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + mmdit.enable_block_swap(args.blocks_to_swap, accelerator.device) + + if not cache_latents: + # move to accelerator device + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + mmdit.requires_grad_(train_mmdit) + if not train_mmdit: + mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared + + if args.num_last_block_to_freeze: + # freeze last n blocks of MM-DIT + block_name = "x_block" + filtered_blocks = [(name, param) for name, param in mmdit.named_parameters() if block_name in name] + accelerator.print(f"filtered_blocks: {len(filtered_blocks)}") + + num_blocks_to_freeze = min(len(filtered_blocks), args.num_last_block_to_freeze) + + accelerator.print(f"freeze_blocks: {num_blocks_to_freeze}") + + start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze) + + for i in range(start_freezing_from, len(filtered_blocks)): + _, param = filtered_blocks[i] + param.requires_grad = False + + training_models = [] + params_to_optimize = [] + param_names = [] + training_models.append(mmdit) + params_to_optimize.append({"params": list(filter(lambda p: p.requires_grad, mmdit.parameters())), "lr": args.learning_rate}) + param_names.append([n for n, _ in mmdit.named_parameters()]) + + if train_clip: + if lr_te1 > 0: + training_models.append(clip_l) + params_to_optimize.append({"params": list(clip_l.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) + param_names.append([n for n, _ in clip_l.named_parameters()]) + if lr_te2 > 0: + training_models.append(clip_g) + params_to_optimize.append({"params": list(clip_g.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) + param_names.append([n for n, _ in clip_g.named_parameters()]) + if train_t5xxl: + training_models.append(t5xxl) + params_to_optimize.append({"params": list(t5xxl.parameters()), "lr": args.learning_rate_te3 or args.learning_rate}) + param_names.append([n for n, _ in t5xxl.named_parameters()]) + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"train mmdit: {train_mmdit} , clip:{train_clip}, t5xxl:{train_t5xxl}") + accelerator.print(f"number of models: {len(training_models)}") + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.blockwise_fused_optimizers: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. + # This balances memory usage and management complexity. + + # split params into groups for mmdit. clip_l, clip_g, t5xxl are in each group + grouped_params = [] + param_group = {} + group = params_to_optimize[0] + named_parameters = list(mmdit.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # joint or other + if np[0].startswith("joint_blocks"): + block_idx = int(np[0].split(".")[1]) + block_type = "joint" + else: + block_idx = -1 + + param_group_key = (block_type, block_idx) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + grouped_params.extend(params_to_optimize[1:]) # add clip_l, clip_g, t5xxl if they are trained + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + optimizer_train_fn = lambda: None # dummy function + optimizer_eval_fn = lambda: None # dummy function + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.blockwise_fused_optimizers: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + mmdit.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + if clip_g is not None: + clip_g.to(weight_dtype) + if t5xxl is not None: + t5xxl.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + mmdit.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + if clip_g is not None: + clip_g.to(weight_dtype) + if t5xxl is not None: + t5xxl.to(weight_dtype) + + # TODO check if this is necessary. SD3 uses pool for clip_l and clip_g + # # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer + # if train_clip_l: + # clip_l.text_model.encoder.layers[-1].requires_grad_(False) + # clip_l.text_model.final_layer_norm.requires_grad_(False) + + # move Text Encoders to GPU if not caching outputs + if not args.cache_text_encoder_outputs: + # make sure Text Encoders are on GPU + # TODO support CPU for text encoders + clip_l.to(accelerator.device) + clip_g.to(accelerator.device) + if t5xxl is not None: + t5xxl.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model( + args, mmdit=mmdit, clip_l=clip_l if train_clip else None, clip_g=clip_g if train_clip else None + ) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # acceleratorがなんかよろしくやってくれるらしい + if train_mmdit: + mmdit = accelerator.prepare(mmdit, device_placement=[not is_swapping_blocks]) + if is_swapping_blocks: + accelerator.unwrap_model(mmdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + if train_clip: + clip_l = accelerator.prepare(clip_l) + clip_g = accelerator.prepare(clip_g) + if train_t5xxl: + t5xxl = accelerator.prepare(t5xxl) + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): + if parameter.requires_grad: + + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook + + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) + + elif args.blockwise_fused_optimizers: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + # only used to get timesteps, etc. TODO manage timesteps etc. separately + dummy_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if is_swapping_blocks: + accelerator.unwrap_model(mmdit).prepare_block_swap_before_forward() + + # For --sample_at_first + optimizer_eval_fn() + sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs) + optimizer_train_fn() + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) + + # show model device and dtype + logger.info( + f"mmdit device: {accelerator.unwrap_model(mmdit).device}, dtype: {accelerator.unwrap_model(mmdit).dtype}" + if mmdit + else "mmdit is None" + ) + logger.info( + f"clip_l device: {accelerator.unwrap_model(clip_l).device}, dtype: {accelerator.unwrap_model(clip_l).dtype}" + if clip_l + else "clip_l is None" + ) + logger.info( + f"clip_g device: {accelerator.unwrap_model(clip_g).device}, dtype: {accelerator.unwrap_model(clip_g).dtype}" + if clip_g + else "clip_g is None" + ) + logger.info( + f"t5xxl device: {accelerator.unwrap_model(t5xxl).device}, dtype: {accelerator.unwrap_model(t5xxl).dtype}" + if t5xxl + else "t5xxl is None" + ) + logger.info( + f"vae device: {accelerator.unwrap_model(vae).device}, dtype: {accelerator.unwrap_model(vae).dtype}" + if vae is not None + else "vae is None" + ) + + loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.blockwise_fused_optimizers: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = vae.encode(batch["images"].to(vae.device, dtype=vae.dtype)).to( + accelerator.device, dtype=weight_dtype + ) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + # latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + latents = sd3_models.SDVAE.process_in(latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) + lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list + if args.use_t5xxl_cache_only: + lg_out = None + lg_pooled = None + else: + lg_out = None + t5_out = None + lg_pooled = None + l_attn_mask = None + g_attn_mask = None + t5_attn_mask = None + + if lg_out is None: + # not cached or training, so get from text encoders + input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"] + with torch.set_grad_enabled(train_clip): + # TODO support weighted captions + # text models in sd3_models require "cpu" for input_ids + input_ids_clip_l = input_ids_clip_l.to("cpu") + input_ids_clip_g = input_ids_clip_g.to("cpu") + lg_out, _, lg_pooled, l_attn_mask, g_attn_mask, _ = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, + [clip_l, clip_g, None], + [input_ids_clip_l, input_ids_clip_g, None, l_attn_mask, g_attn_mask, None], + ) + + if t5_out is None: + _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] + with torch.set_grad_enabled(train_t5xxl): + input_ids_t5xxl = input_ids_t5xxl.to("cpu") + _, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens( + sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] + ) + + context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled) + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + # bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps( + args, latents, noise, accelerator.device, weight_dtype + ) + + # debug: NaN check for all inputs + if torch.any(torch.isnan(noisy_model_input)): + accelerator.print("NaN found in noisy_model_input, replacing with zeros") + noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input) + if torch.any(torch.isnan(context)): + accelerator.print("NaN found in context, replacing with zeros") + context = torch.nan_to_num(context, 0, out=context) + if torch.any(torch.isnan(lg_pooled)): + accelerator.print("NaN found in pool, replacing with zeros") + lg_pooled = torch.nan_to_num(lg_pooled, 0, out=lg_pooled) + + # call model + with accelerator.autocast(): + # TODO support attention mask + model_pred = mmdit(noisy_model_input, timesteps, context=context, y=lg_pooled) + + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + # Preconditioning of the model outputs. + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = latents + + # # Compute regular loss. TODO simplify this + # loss = torch.mean( + # (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + # 1, + # ) + # calculate loss + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, dummy_scheduler) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + if weighting is not None: + loss = loss * weighting + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + optimizer_eval_fn() + sd3_train_utils.sample_images( + accelerator, args, None, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + sd3_train_utils.save_sd3_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(clip_l) if train_clip else None, + accelerator.unwrap_model(clip_g) if train_clip else None, + accelerator.unwrap_model(t5xxl) if train_t5xxl else None, + accelerator.unwrap_model(mmdit) if train_mmdit else None, + vae, + ) + optimizer_train_fn() + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if len(accelerator.trackers) > 0: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_mmdit) + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + optimizer_eval_fn() + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + sd3_train_utils.save_sd3_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(clip_l) if train_clip else None, + accelerator.unwrap_model(clip_g) if train_clip else None, + accelerator.unwrap_model(t5xxl) if train_t5xxl else None, + accelerator.unwrap_model(mmdit) if train_mmdit else None, + vae, + ) + + sd3_train_utils.sample_images( + accelerator, args, epoch + 1, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs + ) + + is_main_process = accelerator.is_main_process + # if is_main_process: + mmdit = accelerator.unwrap_model(mmdit) + clip_l = accelerator.unwrap_model(clip_l) + clip_g = accelerator.unwrap_model(clip_g) + if t5xxl is not None: + t5xxl = accelerator.unwrap_model(t5xxl) + + accelerator.end_training() + optimizer_eval_fn() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + sd3_train_utils.save_sd3_model_on_train_end( + args, + save_dtype, + epoch, + global_step, + clip_l if train_clip else None, + clip_g if train_clip else None, + t5xxl if train_t5xxl else None, + mmdit if train_mmdit else None, + vae, + ) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) + train_util.add_dit_training_arguments(parser) + sd3_train_utils.add_sd3_training_arguments(parser) + + parser.add_argument( + "--train_text_encoder", action="store_true", help="train text encoder (CLIP-L and G) / text encoderも学習する" + ) + parser.add_argument("--train_t5xxl", action="store_true", help="train T5-XXL / T5-XXLも学習する") + parser.add_argument( + "--use_t5xxl_cache_only", action="store_true", help="cache T5-XXL outputs only / T5-XXLの出力のみキャッシュする" + ) + + parser.add_argument( + "--learning_rate_te1", + type=float, + default=None, + help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率", + ) + parser.add_argument( + "--learning_rate_te2", + type=float, + default=None, + help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率", + ) + parser.add_argument( + "--learning_rate_te3", + type=float, + default=None, + help="learning rate for text encoder 3 (T5-XXL) / text encoder 3 (T5-XXL)の学習率", + ) + + # parser.add_argument( + # "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + # ) + # parser.add_argument( + # "--no_half_vae", + # action="store_true", + # help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + # ) + # parser.add_argument( + # "--block_lr", + # type=str, + # default=None, + # help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " + # + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", + # ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="[DOES NOT WORK] number of optimizer groups for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizerグループ数", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください", + ) + parser.add_argument( + "--num_last_block_to_freeze", + type=int, + default=None, + help="freeze last n blocks of MM-DIT / MM-DITの最後のnブロックを凍結する", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/sd3_train_network.py b/sd3_train_network.py new file mode 100644 index 00000000..fb7711bd --- /dev/null +++ b/sd3_train_network.py @@ -0,0 +1,480 @@ +import argparse +import copy +import math +import random +from typing import Any, Optional + +import torch +from accelerate import Accelerator +from library import sd3_models, strategy_sd3, utils +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3, train_util +import train_network +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class Sd3NetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + + def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup): + # super().assert_extra_args(args, train_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for SD3 + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # prepare CLIP-L/CLIP-G/T5XXL training flags + self.train_clip = not args.network_train_unet_only + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + + if args.max_token_length is not None: + logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + + # enumerate resolutions from dataset for positional embeddings + self.resolutions = train_dataset_group.get_resolutions() + + def load_target_model(self, args, weight_dtype, accelerator): + # currently offload to cpu for some models + + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future + state_dict = utils.load_safetensors( + args.pretrained_model_name_or_path, "cpu", disable_mmap=args.disable_mmap_load_safetensors, dtype=loading_dtype + ) + mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu") + self.model_type = mmdit.model_type + mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate) + + # set resolutions for positional embeddings + if args.enable_scaled_pos_embed: + latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent + latent_sizes = list(set(latent_sizes)) # remove duplicates + logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}") + mmdit.enable_scaled_pos_embed(True, latent_sizes) + + if args.fp8_base: + # check dtype of model + if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {mmdit.dtype}") + elif mmdit.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 SD3 model") + else: + logger.info( + "Cast SD3 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + " / SD3モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。" + ) + mmdit.to(torch.float8_e4m3fn) + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + mmdit.enable_block_swap(args.blocks_to_swap, accelerator.device) + + clip_l = sd3_utils.load_clip_l( + args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + clip_l.eval() + clip_g = sd3_utils.load_clip_g( + args.clip_g, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + clip_g.eval() + + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + + # loading t5xxl to cpu takes a long time, so we should load to gpu in future + t5xxl = sd3_utils.load_t5xxl( + args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + vae = sd3_utils.load_vae( + args.vae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + + return mmdit.model_type, [clip_l, clip_g, t5xxl], vae, mmdit + + def get_tokenize_strategy(self, args): + logger.info(f"t5xxl_max_token_length: {args.t5xxl_max_token_length}") + return strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy): + return [tokenize_strategy.clip_l, tokenize_strategy.clip_g, tokenize_strategy.t5xxl] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sd3.Sd3TextEncodingStrategy( + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + args.clip_l_dropout_rate, + args.clip_g_dropout_rate, + args.t5_dropout_rate, + ) + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + if args.cache_text_encoder_outputs: + if self.train_clip and not self.train_t5xxl: + return text_encoders[0:2] + [None] # only CLIP-L/CLIP-G is needed for encoding because T5XXL is cached + else: + return None # no text encoders are needed for encoding because both are cached + else: + return text_encoders # CLIP-L, CLIP-G and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_clip, self.train_clip, self.train_t5xxl] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_clip or self.train_t5xxl, + apply_lg_attn_mask=args.apply_lg_attn_mask, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[2].to(accelerator.device) # may be fp8 + + if text_encoders[2].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(2, text_encoders[2], text_encoders[2].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[2].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = train_util.load_prompts(args.sample_prompts) + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, + text_encoders, + tokens_and_masks, + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # move back to cpu + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move CLIP-G back to cpu") + text_encoders[1].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[2].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) + text_encoders[2].to(accelerator.device) + + # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # # get size embeddings + # orig_size = batch["original_sizes_hw"] + # crop_size = batch["crop_top_lefts"] + # target_size = batch["target_sizes_hw"] + # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # # concat embeddings + # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + # return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, mmdit): + text_encoders = text_encoder # for compatibility + text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) + + sd3_train_utils.sample_images( + accelerator, args, epoch, global_step, mmdit, vae, text_encoders, self.sample_prompts_te_outputs + ) + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + # this scheduler is not used in training, but used to get num_train_timesteps etc. + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images) + + def shift_scale_latents(self, args, latents): + return sd3_models.SDVAE.process_in(latents) + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: flux_models.Flux, + network, + weight_dtype, + train_unet, + ): + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps( + args, latents, noise, accelerator.device, weight_dtype + ) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t is not None and t.dtype.is_floating_point: + t.requires_grad_(True) + + # Predict the noise residual + lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds + text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled) + if not args.apply_lg_attn_mask: + l_attn_mask = None + g_attn_mask = None + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + # call model + with accelerator.autocast(): + # TODO support attention mask + model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled) + + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + # Preconditioning of the model outputs. + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = latents + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + model_pred_prior = unet( + noisy_model_input[diff_output_pr_indices], + timesteps[diff_output_pr_indices], + context=context[diff_output_pr_indices], + y=lg_pooled[diff_output_pr_indices], + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + + model_pred_prior = model_pred_prior * (-sigmas[diff_output_pr_indices]) + noisy_model_input[diff_output_pr_indices] + + # weighting for differential output preservation is not needed because it is already applied + + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + + return model_pred, target, timesteps, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, sd3=self.model_type) + + def update_metadata(self, metadata, args): + metadata["ss_apply_lg_attn_mask"] = args.apply_lg_attn_mask + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + if index == 0 or index == 1: # CLIP-L/CLIP-G + return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) + else: # T5XXL + text_encoder.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + if index == 0 or index == 1: # CLIP-L/CLIP-G + clip_type = "CLIP-L" if index == 0 else "CLIP-G" + logger.info(f"prepare CLIP-{clip_type} for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + else: # T5XXL + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"T5XXL already prepared for fp8") + else: + logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + prepare_fp8(text_encoder, weight_dtype) + + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + # drop cached text encoder outputs + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) + batch["text_encoder_outputs_list"] = text_encoder_outputs_list + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + mmdit: sd3_models.MMDiT = unet + mmdit = accelerator.prepare(mmdit, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(mmdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(mmdit).prepare_block_swap_before_forward() + + return mmdit + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) + sd3_train_utils.add_sd3_training_arguments(parser) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = Sd3NetworkTrainer() + trainer.train(args) diff --git a/sdxl_train.py b/sdxl_train.py index 46d7860b..b9d52924 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -17,7 +17,7 @@ init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler -from library import deepspeed_utils, sdxl_model_util +from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl import library.train_util as train_util @@ -104,8 +104,8 @@ def train(args): setup_logging(args, reset=True) assert ( - not args.weighted_captions - ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + not args.weighted_captions or not args.cache_text_encoder_outputs + ), "weighted_captions is not supported when caching text encoder outputs / cache_text_encoder_outputsを使うときはweighted_captionsはサポートされていません" assert ( not args.train_text_encoder or not args.cache_text_encoder_outputs ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" @@ -124,7 +124,16 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizers = [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] # will be removed in the future + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -166,10 +175,10 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2]) + train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -262,8 +271,9 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator) + vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -272,10 +282,13 @@ def train(args): # 学習を準備する:モデルを適切な状態にする if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - train_unet = args.learning_rate > 0 + train_unet = args.learning_rate != 0 train_text_encoder1 = False train_text_encoder2 = False + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + if args.train_text_encoder: # TODO each option for two text encoders? accelerator.print("enable text encoder training") @@ -284,8 +297,8 @@ def train(args): text_encoder2.gradient_checkpointing_enable() lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train - train_text_encoder1 = lr_te1 > 0 - train_text_encoder2 = lr_te2 > 0 + train_text_encoder1 = lr_te1 != 0 + train_text_encoder2 = lr_te2 != 0 # caching one text encoder output is not supported if not train_text_encoder1: @@ -307,16 +320,17 @@ def train(args): # TextEncoderの出力をキャッシュする if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad - with torch.no_grad(), accelerator.autocast(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) - accelerator.wait_for_everyone() + text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) + + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) + + accelerator.wait_for_everyone() if not cache_latents: vae.requires_grad_(False) @@ -345,8 +359,8 @@ def train(args): # calculate number of trainable parameters n_params = 0 - for params in params_to_optimize: - for p in params["params"]: + for group in params_to_optimize: + for p in group["params"]: n_params += p.numel() accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}") @@ -355,9 +369,59 @@ def train(args): # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) - # dataloaderを準備する + if args.fused_optimizer_groups: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # This balances memory usage and management complexity. + + # calculate total number of parameters + n_total_params = sum(len(params["params"]) for params in params_to_optimize) + params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) + + # split params into groups, keeping the learning rate the same for all params in a group + # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) + grouped_params = [] + param_group = [] + param_group_lr = -1 + for group in params_to_optimize: + lr = group["lr"] + for p in group["params"]: + # if the learning rate is different for different params, start a new group + if lr != param_group_lr: + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = lr + + param_group.append(p) + + # if the group has enough parameters, start a new group + if len(param_group) == params_per_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = -1 + + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -382,7 +446,12 @@ def train(args): train_dataset_group.set_max_train_steps(args.max_train_steps) # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + if args.fused_optimizer_groups: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする if args.full_fp16: @@ -450,6 +519,57 @@ def train(args): # resumeする train_util.resume_from_local_or_hf_if_specified(accelerator, args) + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + elif args.fused_optimizer_groups: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def optimizer_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(optimizer_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) @@ -487,12 +607,19 @@ def train(args): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) # For --sample_at_first sdxl_train_util.sample_images( - accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet + accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, [text_encoder1, text_encoder2], unet ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): @@ -504,6 +631,10 @@ def train(args): for step, batch in enumerate(train_dataloader): current_step.value = global_step + + if args.fused_optimizer_groups: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + with accelerator.accumulate(*training_models): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) @@ -518,57 +649,39 @@ def train(args): latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + # Text Encoder outputs are cached + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list + encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype) + pool2 = pool2.to(accelerator.device, dtype=weight_dtype) + else: + input_ids1, input_ids2 = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning - # TODO support weighted captions - # if args.weighted_captions: - # encoder_hidden_states = get_weighted_text_embeddings( - # tokenizer, - # text_encoder, - # batch["captions"], - # accelerator.device, - # args.max_token_length // 75 if args.max_token_length else 1, - # clip_skip=args.clip_skip, - # ) - # else: - input_ids1 = input_ids1.to(accelerator.device) - input_ids2 = input_ids2.to(accelerator.device) - # unwrap_model is fine for models not wrapped by accelerator - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, - accelerator=accelerator, - ) - else: - encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) - - # # verify that the text encoder outputs are correct - # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl( - # args.max_token_length, - # batch["input_ids"].to(text_encoder1.device), - # batch["input_ids2"].to(text_encoder1.device), - # tokenizer1, - # tokenizer2, - # text_encoder1, - # text_encoder2, - # None if not args.full_fp16 else weight_dtype, - # ) - # b_size = encoder_hidden_states1.shape[0] - # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # logger.info("text encoder outputs verified") + if args.weighted_captions: + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoder_hidden_states1, encoder_hidden_states2, pool2 = ( + text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + [text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)], + input_ids_list, + weights_list, + ) + ) + else: + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, + [text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)], + [input_ids1, input_ids2], + ) + if args.full_fp16: + encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) + pool2 = pool2.to(weight_dtype) # get size embeddings orig_size = batch["original_sizes_hw"] @@ -582,7 +695,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -590,8 +703,13 @@ def train(args): with accelerator.autocast(): noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) - target = noise + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) if ( args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred @@ -600,34 +718,42 @@ def train(args): or args.masked_loss ): # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c) accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = [] - for m in training_models: - params_to_clip.extend(m.parameters()) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + if not (args.fused_backward_pass or args.fused_optimizer_groups): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: @@ -641,7 +767,7 @@ def train(args): global_step, accelerator.device, vae, - [tokenizer1, tokenizer2], + tokenizers, [text_encoder1, text_encoder2], unet, ) @@ -671,7 +797,7 @@ def train(args): ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} if block_lrs is None: train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet) @@ -688,7 +814,7 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) @@ -723,7 +849,7 @@ def train(args): global_step, accelerator.device, vae, - [tokenizer1, tokenizer2], + tokenizers, [text_encoder1, text_encoder2], unet, ) @@ -736,7 +862,7 @@ def train(args): accelerator.end_training() - if args.save_state or args.save_state_on_train_end: + if args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す @@ -805,6 +931,12 @@ def setup_parser() -> argparse.ArgumentParser: help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", ) + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + ) return parser diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py new file mode 100644 index 00000000..ffbf03ca --- /dev/null +++ b/sdxl_train_control_net.py @@ -0,0 +1,719 @@ +import argparse +import math +import os +import random +from multiprocessing import Value +import toml + +from tqdm import tqdm + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from accelerate import init_empty_weights +from diffusers import DDPMScheduler +from diffusers.utils.torch_utils import is_compiled_module +from safetensors.torch import load_file +from library import ( + deepspeed_utils, + sai_model_spec, + sdxl_model_util, + sdxl_train_util, + strategy_base, + strategy_sd, + strategy_sdxl, +) + +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.huggingface_util as huggingface_util +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + add_v_prediction_like_loss, + apply_snr_weight, + prepare_scheduler_for_custom_training, + scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, +) +from library.sdxl_original_control_net import SdxlControlNet, SdxlControlledUNet +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +# TODO 他のスクリプトと共通化する +def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): + logs = { + "loss/current": current_loss, + "loss/average": avr_loss, + "lr": lr_scheduler.get_last_lr()[0], + } + + if args.optimizer_type.lower().startswith("DAdapt".lower()): + logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + + return logs + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + sdxl_train_util.verify_sdxl_training_args(args) + setup_logging(args, reset=True) + + cache_latents = args.cache_latents + use_user_config = args.dataset_config is not None + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizer1, tokenizer2 = tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2 # this is used for sampling images + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) + if use_user_config: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "conditioning_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension, + ) + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(32) + + if args.debug_dataset: + train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + else: + logger.warning( + "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" + ) + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) + + unet.to(accelerator.device) # reduce main memory usage + + # convert U-Net to Controlled U-Net + logger.info("convert U-Net to Controlled U-Net") + unet_sd = unet.state_dict() + with init_empty_weights(): + unet = SdxlControlledUNet() + unet.load_state_dict(unet_sd, strict=True, assign=True) + del unet_sd + + # make control net + logger.info("make ControlNet") + if args.controlnet_model_name_or_path: + with init_empty_weights(): + control_net = SdxlControlNet() + + logger.info(f"load ControlNet from {args.controlnet_model_name_or_path}") + filename = args.controlnet_model_name_or_path + if os.path.splitext(filename)[1] == ".safetensors": + state_dict = load_file(filename) + else: + state_dict = torch.load(filename) + info = control_net.load_state_dict(state_dict, strict=True, assign=True) + logger.info(f"ControlNet loaded from {filename}: {info}") + else: + control_net = SdxlControlNet() + + logger.info("initialize ControlNet from U-Net") + info = control_net.init_from_unet(unet) + logger.info(f"ControlNet initialized from U-Net: {info}") + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + + train_dataset_group.new_cache_latents(vae, accelerator) + + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad + text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) + + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) + + accelerator.wait_for_everyone() + + # モデルに xformers とか memory efficient attention を組み込む + # train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if args.xformers: + unet.set_use_memory_efficient_attention(True, False) + control_net.set_use_memory_efficient_attention(True, False) + elif args.sdpa: + unet.set_use_sdpa(True) + control_net.set_use_sdpa(True) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + control_net.enable_gradient_checkpointing() + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + trainable_params = [] + ctrlnet_params = [] + unet_params = [] + for name, param in control_net.named_parameters(): + if name.startswith("controlnet_"): + ctrlnet_params.append(param) + else: + unet_params.append(param) + trainable_params.append({"params": ctrlnet_params, "lr": args.control_net_lr}) + trainable_params.append({"params": unet_params, "lr": args.learning_rate}) + all_params = ctrlnet_params + unet_params + + logger.info(f"trainable params count: {len(all_params)}") + logger.info(f"number of trainable parameters: {sum(p.numel() for p in all_params)}") + + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + control_net.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + control_net.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + control_net, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + control_net, optimizer, train_dataloader, lr_scheduler + ) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + unet.requires_grad_(False) + text_encoder1.requires_grad_(False) + text_encoder2.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + + unet.eval() + control_net.train() + + # TextEncoderの出力をキャッシュするときにはCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + text_encoder1.to("cpu", dtype=torch.float32) + text_encoder2.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + ("sdxl_control_net_train" if args.log_tracker_name is None else args.log_tracker_name), + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + loss_recorder = train_util.LossRecorder() + del train_dataset_group + + # function for saving/removing + def save_model(ckpt_name, model, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) + sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet" + state_dict = model.state_dict() + + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + if os.path.splitext(ckpt_file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, ckpt_file, sai_metadata) + else: + torch.save(state_dict, ckpt_file) + + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # For --sample_at_first + sdxl_train_util.sample_images( + accelerator, + args, + 0, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], + unet, + controlnet=control_net, + ) + + # training loop + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + control_net.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(control_net): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + # Text Encoder outputs are cached + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list + encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype) + pool2 = pool2.to(accelerator.device, dtype=weight_dtype) + else: + input_ids1, input_ids2 = batch["input_ids_list"] + with torch.no_grad(): + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], [input_ids1, input_ids2] + ) + if args.full_fp16: + encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) + pool2 = pool2.to(weight_dtype) + + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # concat embeddings + vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + + controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) + + # '-1 to +1' to '0 to 1' + controlnet_image = (controlnet_image + 1) / 2 + + with accelerator.autocast(): + input_resi_add, mid_add = control_net( + noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image + ) + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, input_resi_add, mid_add) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if not args.fused_backward_pass: + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = control_net.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + sdxl_train_util.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], + unet, + controlnet=control_net, + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model(ckpt_name, unwrap_model(control_net)) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + + current_loss = loss.detach().item() + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if len(accelerator.trackers) > 0: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + # 指定エポックごとにモデルを保存 + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name, unwrap_model(control_net)) + + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) + + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + + sdxl_train_util.sample_images( + accelerator, + args, + epoch + 1, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], + unet, + controlnet=control_net, + ) + + # end of epoch + + if is_main_process: + control_net = unwrap_model(control_net) + + accelerator.end_training() + + if is_main_process and (args.save_state or args.save_state_on_train_end): + train_util.save_state_on_train_end(args, accelerator) + + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) + save_model(ckpt_name, control_net, force_sync_upload=True) + + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + # train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + # train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser) + + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="controlnet model name or path / controlnetのモデル名またはパス", + ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + parser.add_argument( + "--control_net_lr", + type=float, + default=1e-4, + help="learning rate for controlnet modules / controlnetモジュールの学習率", + ) + return parser + + +if __name__ == "__main__": + # sdxl_original_unet.USE_REENTRANT = False + + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index f89c3628..365059b7 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -15,6 +15,7 @@ from tqdm import tqdm import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -22,7 +23,16 @@ from accelerate.utils import set_seed import accelerate from diffusers import DDPMScheduler, ControlNetModel from safetensors.torch import load_file -from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util +from library import ( + deepspeed_utils, + sai_model_spec, + sdxl_model_util, + sdxl_original_unet, + sdxl_train_util, + strategy_base, + strategy_sd, + strategy_sdxl, +) import library.model_util as model_util import library.train_util as train_util @@ -78,7 +88,14 @@ def train(args): args.seed = random.randint(0, 2**32) set_seed(args.seed) - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) @@ -105,7 +122,7 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) @@ -163,30 +180,30 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) + + train_dataset_group.new_cache_latents(vae, accelerator) + vae.to("cpu") clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # TextEncoderの出力をキャッシュする if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad - with torch.no_grad(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) + text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) + + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator) + accelerator.wait_for_everyone() # prepare ControlNet-LLLite @@ -241,7 +258,11 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -288,6 +309,9 @@ def train(args): # acceleratorがなんかよろしくやってくれるらしい unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + if isinstance(unet, DDP): + unet._set_static_graph() # avoid error for multiple use of the parameter + if args.gradient_checkpointing: unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる else: @@ -353,7 +377,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -405,27 +431,25 @@ def train(args): latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + # Text Encoder outputs are cached + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list + encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype) + pool2 = pool2.to(accelerator.device, dtype=weight_dtype) + else: + input_ids1, input_ids2 = batch["input_ids_list"] with torch.no_grad(): - # Get the text embedding for conditioning input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2] ) - else: - encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + if args.full_fp16: + encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype) + encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype) + pool2 = pool2.to(weight_dtype) # get size embeddings orig_size = batch["original_sizes_hw"] @@ -439,7 +463,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -458,7 +482,8 @@ def train(args): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight @@ -471,13 +496,13 @@ def train(args): if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = unet.get_trainable_params() + params_to_clip = accelerator.unwrap_model(unet).get_trainable_params() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() @@ -512,14 +537,14 @@ def train(args): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index e85e978c..5b372bef 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -12,6 +12,7 @@ from tqdm import tqdm import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -324,7 +325,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -406,7 +409,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -426,7 +429,8 @@ def train(args): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight @@ -439,7 +443,7 @@ def train(args): if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし @@ -480,14 +484,14 @@ def train(args): logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 83969bb1..d45df6e0 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,16 +1,22 @@ import argparse +from typing import List, Optional import torch +from accelerate import Accelerator from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() -from library import sdxl_model_util, sdxl_train_util, train_util +from library import sdxl_model_util, sdxl_train_util, strategy_base, strategy_sd, strategy_sdxl, train_util import train_network from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + class SdxlNetworkTrainer(train_network.NetworkTrainer): def __init__(self): super().__init__() @@ -47,17 +53,41 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): self.logit_scale = logit_scale self.ckpt_info = ckpt_info + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) + return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet - def load_tokenizer(self, args): - tokenizer = sdxl_train_util.load_tokenizers(args) - return tokenizer + def get_tokenize_strategy(self, args): + return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) - def is_text_encoder_outputs_cached(self, args): - return args.cache_text_encoder_outputs + def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy): + return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sdxl.SdxlTextEncodingStrategy() + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + return text_encoders + [accelerator.unwrap_model(text_encoders[-1])] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions + ) + else: + return None def cache_text_encoder_outputs_if_needed( - self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype ): if args.cache_text_encoder_outputs: if not args.lowram: @@ -70,15 +100,11 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): clean_memory_on_device(accelerator.device) # When TE is not be trained, it will not be prepared so we need to use explicit autocast + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) with accelerator.autocast(): - dataset.cache_text_encoder_outputs( - tokenizers, - text_encoders, - accelerator.device, - weight_dtype, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) + dataset.new_cache_text_encoder_outputs(text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator) + accelerator.wait_for_everyone() text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU text_encoders[1].to("cpu", dtype=torch.float32) @@ -147,7 +173,18 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): return encoder_hidden_states1, encoder_hidden_states2, pool2 - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + def call_unet( + self, + args, + accelerator, + unet, + noisy_latents, + timesteps, + text_conds, + batch, + weight_dtype, + indices: Optional[List[int]] = None, + ): noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype # get size embeddings @@ -161,6 +198,12 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + if indices is not None and len(indices) > 0: + noisy_latents = noisy_latents[indices] + timesteps = timesteps[indices] + text_embedding = text_embedding[indices] + vector_embedding = vector_embedding[indices] + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) return noise_pred diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index 5df739e2..821a6955 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -5,10 +5,10 @@ import regex import torch from library.device_utils import init_ipex + init_ipex() -from library import sdxl_model_util, sdxl_train_util, train_util - +from library import sdxl_model_util, sdxl_train_util, strategy_sd, strategy_sdxl, train_util import train_textual_inversion @@ -41,28 +41,20 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet - def load_tokenizer(self, args): - tokenizer = sdxl_train_util.load_tokenizers(args) - return tokenizer + def get_tokenize_strategy(self, args): + return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] - with torch.enable_grad(): - input_ids1 = input_ids1.to(accelerator.device) - input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizers[0], - tokenizers[1], - text_encoders[0], - text_encoders[1], - None if not args.full_fp16 else weight_dtype, - accelerator=accelerator, - ) - return encoder_hidden_states1, encoder_hidden_states2, pool2 + def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy): + return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sdxl.SdxlTextEncodingStrategy() def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -81,9 +73,11 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) return noise_pred - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): + def sample_images( + self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement + ): sdxl_train_util.sample_images( - accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement + accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement ) def save_weights(self, file, updated_embs, save_dtype, metadata): @@ -122,8 +116,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine def setup_parser() -> argparse.ArgumentParser: parser = train_textual_inversion.setup_parser() - # don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching - # sdxl_train_util.add_sdxl_training_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False) return parser diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 00000000..9836da8b --- /dev/null +++ b/tests/README.md @@ -0,0 +1,41 @@ +# Tests + +## Install + +``` +pip install pytest +``` + +## Usage + +``` +pytest +``` + +## Contribution + +Pytest is configured to run tests in this directory. It might be a good idea to add tests closer in the code, as well as doctests. + +Tests are functions starting with `test_` and files with the pattern `test_*.py`. + +``` +def test_x(): + assert 1 == 2, "Invalid test response" +``` + +## Resources + +### pytest + +- https://docs.pytest.org/en/stable/index.html +- https://docs.pytest.org/en/stable/how-to/assert.html +- https://docs.pytest.org/en/stable/how-to/doctest.html + +### PyTorch testing + +- https://circleci.com/blog/testing-pytorch-model-with-pytest/ +- https://pytorch.org/docs/stable/testing.html +- https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests +- https://github.com/huggingface/pytorch-image-models/tree/main/tests +- https://github.com/pytorch/pytorch/tree/main/test + diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 00000000..f6ade91a --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,153 @@ +from unittest.mock import patch +from library.train_util import get_optimizer +from train_network import setup_parser +import torch +from torch.nn import Parameter + +# Optimizer libraries +import bitsandbytes as bnb +from lion_pytorch import lion_pytorch +import schedulefree + +import dadaptation +import dadaptation.experimental as dadapt_experimental + +import prodigyopt +import schedulefree as sf +import transformers + + +def test_default_get_optimizer(): + with patch("sys.argv", [""]): + parser = setup_parser() + args = parser.parse_args() + params_t = torch.tensor([1.5, 1.5]) + + param = Parameter(params_t) + optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param]) + assert optimizer_name == "torch.optim.adamw.AdamW" + assert optimizer_args == "" + assert isinstance(optimizer, torch.optim.AdamW) + + +def test_get_schedulefree_optimizer(): + with patch("sys.argv", ["", "--optimizer_type", "AdamWScheduleFree"]): + parser = setup_parser() + args = parser.parse_args() + params_t = torch.tensor([1.5, 1.5]) + + param = Parameter(params_t) + optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param]) + assert optimizer_name == "schedulefree.adamw_schedulefree.AdamWScheduleFree" + assert optimizer_args == "" + assert isinstance(optimizer, schedulefree.adamw_schedulefree.AdamWScheduleFree) + + +def test_all_supported_optimizers(): + optimizers = [ + { + "name": "bitsandbytes.optim.adamw.AdamW8bit", + "alias": "AdamW8bit", + "instance": bnb.optim.AdamW8bit, + }, + { + "name": "lion_pytorch.lion_pytorch.Lion", + "alias": "Lion", + "instance": lion_pytorch.Lion, + }, + { + "name": "torch.optim.adamw.AdamW", + "alias": "AdamW", + "instance": torch.optim.AdamW, + }, + { + "name": "bitsandbytes.optim.lion.Lion8bit", + "alias": "Lion8bit", + "instance": bnb.optim.Lion8bit, + }, + { + "name": "bitsandbytes.optim.adamw.PagedAdamW8bit", + "alias": "PagedAdamW8bit", + "instance": bnb.optim.PagedAdamW8bit, + }, + { + "name": "bitsandbytes.optim.lion.PagedLion8bit", + "alias": "PagedLion8bit", + "instance": bnb.optim.PagedLion8bit, + }, + { + "name": "bitsandbytes.optim.adamw.PagedAdamW", + "alias": "PagedAdamW", + "instance": bnb.optim.PagedAdamW, + }, + { + "name": "bitsandbytes.optim.adamw.PagedAdamW32bit", + "alias": "PagedAdamW32bit", + "instance": bnb.optim.PagedAdamW32bit, + }, + {"name": "torch.optim.sgd.SGD", "alias": "SGD", "instance": torch.optim.SGD}, + { + "name": "dadaptation.experimental.dadapt_adam_preprint.DAdaptAdamPreprint", + "alias": "DAdaptAdamPreprint", + "instance": dadapt_experimental.DAdaptAdamPreprint, + }, + { + "name": "dadaptation.dadapt_adagrad.DAdaptAdaGrad", + "alias": "DAdaptAdaGrad", + "instance": dadaptation.DAdaptAdaGrad, + }, + { + "name": "dadaptation.dadapt_adan.DAdaptAdan", + "alias": "DAdaptAdan", + "instance": dadaptation.DAdaptAdan, + }, + { + "name": "dadaptation.experimental.dadapt_adan_ip.DAdaptAdanIP", + "alias": "DAdaptAdanIP", + "instance": dadapt_experimental.DAdaptAdanIP, + }, + { + "name": "dadaptation.dadapt_lion.DAdaptLion", + "alias": "DAdaptLion", + "instance": dadaptation.DAdaptLion, + }, + { + "name": "dadaptation.dadapt_sgd.DAdaptSGD", + "alias": "DAdaptSGD", + "instance": dadaptation.DAdaptSGD, + }, + { + "name": "prodigyopt.prodigy.Prodigy", + "alias": "Prodigy", + "instance": prodigyopt.Prodigy, + }, + { + "name": "transformers.optimization.Adafactor", + "alias": "Adafactor", + "instance": transformers.optimization.Adafactor, + }, + { + "name": "schedulefree.adamw_schedulefree.AdamWScheduleFree", + "alias": "AdamWScheduleFree", + "instance": sf.AdamWScheduleFree, + }, + { + "name": "schedulefree.sgd_schedulefree.SGDScheduleFree", + "alias": "SGDScheduleFree", + "instance": sf.SGDScheduleFree, + }, + ] + + for opt in optimizers: + with patch("sys.argv", ["", "--optimizer_type", opt.get("alias")]): + parser = setup_parser() + args = parser.parse_args() + params_t = torch.tensor([1.5, 1.5]) + + param = Parameter(params_t) + optimizer_name, _, optimizer = get_optimizer(args, [param]) + assert optimizer_name == opt.get("name") + + instance = opt.get("instance") + assert instance is not None + assert isinstance(optimizer, instance) diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 347db27f..c034f949 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -9,47 +9,82 @@ from accelerate.utils import set_seed import torch from tqdm import tqdm -from library import config_util +from library import config_util, flux_train_utils, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl from library import train_util from library import sdxl_train_util from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments + setup_logging() import logging + logger = logging.getLogger(__name__) -def cache_to_disk(args: argparse.Namespace) -> None: - train_util.prepare_dataset_args(args, True) - # check cache latents arg - assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" +def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argparse.Namespace) -> None: + if is_flux: + _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path) + else: + is_schnell = False + + if is_sd: + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + elif is_sdxl: + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + else: + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") + tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + +def cache_to_disk(args: argparse.Namespace) -> None: + setup_logging(args, reset=True) + train_util.prepare_dataset_args(args, True) + train_util.enable_high_vram(args) + + # assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" + args.cache_latents = True + args.cache_latents_to_disk = True use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - # tokenizerを準備する:datasetを動かすために必要 - if args.sdxl: - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - tokenizers = [tokenizer1, tokenizer2] + is_sd = not args.sdxl and not args.flux + is_sdxl = args.sdxl + is_flux = args.flux + + set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) + + if is_sd or is_sdxl: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(is_sd, True, args.vae_batch_size, args.skip_cache_check) else: - tokenizer = train_util.load_tokenizer(args) - tokenizers = [tokenizer] + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(True, args.vae_batch_size, args.skip_cache_check) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する + use_user_config = args.dataset_config is not None if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) - if args.dataset_config is not None: - logger.info(f"Load dataset config from {args.dataset_config}") + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if use_user_config: + logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] + ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) @@ -80,20 +115,15 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) - - # datasetのcache_latentsを呼ばなければ、生の画像が返る - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + # use arbitrary dataset class + train_dataset_group = train_util.load_arbitrary_dataset(args) # acceleratorを準備する logger.info("prepare accelerator") + args.deepspeed = False accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -102,79 +132,43 @@ def cache_to_disk(args: argparse.Namespace) -> None: # モデルを読み込む logger.info("load model") - if args.sdxl: + if is_sd: + _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) + elif is_sdxl: (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) else: - _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) + vae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + + if is_sd or is_sdxl: + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える - vae.set_use_memory_efficient_attention_xformers(args.xformers) vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - # dataloaderを準備する - train_dataset_group.set_caching_mode("latents") - - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず - train_dataloader = accelerator.prepare(train_dataloader) - - # データ取得のためのループ - for batch in tqdm(train_dataloader): - b_size = len(batch["images"]) - vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size - flip_aug = batch["flip_aug"] - random_crop = batch["random_crop"] - bucket_reso = batch["bucket_reso"] - - # バッチを分割して処理する - for i in range(0, b_size, vae_batch_size): - images = batch["images"][i : i + vae_batch_size] - absolute_paths = batch["absolute_paths"][i : i + vae_batch_size] - resized_sizes = batch["resized_sizes"][i : i + vae_batch_size] - - image_infos = [] - for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)): - image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) - image_info.image = image - image_info.bucket_reso = bucket_reso - image_info.resized_size = resized_size - image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz" - - if args.skip_existing: - if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug): - logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") - continue - - image_infos.append(image_info) - - if len(image_infos) > 0: - train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop) + # cache latents with dataset + # TODO use DataLoader to speed up + train_dataset_group.new_cache_latents(vae, accelerator) accelerator.wait_for_everyone() - accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + accelerator.print(f"Finished caching latents to disk.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") + parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") parser.add_argument( "--no_half_vae", action="store_true", @@ -183,7 +177,8 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--skip_existing", action="store_true", - help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check." + " / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。", ) return parser diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 5f1d6d20..5888b8e3 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -9,54 +9,69 @@ from accelerate.utils import set_seed import torch from tqdm import tqdm -from library import config_util +from library import ( + config_util, + flux_train_utils, + flux_utils, + sdxl_model_util, + strategy_base, + strategy_flux, + strategy_sd, + strategy_sdxl, +) from library import train_util from library import sdxl_train_util +from library import utils from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) -from library.utils import setup_logging +from library.utils import setup_logging, add_logging_arguments +from cache_latents import set_tokenize_strategy + setup_logging() import logging + logger = logging.getLogger(__name__) + def cache_to_disk(args: argparse.Namespace) -> None: + setup_logging(args, reset=True) train_util.prepare_dataset_args(args, True) + train_util.enable_high_vram(args) - # check cache arg - assert ( - args.cache_text_encoder_outputs_to_disk - ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります" - - # できるだけ準備はしておくが今のところSDXLのみしか動かない - assert ( - args.sdxl - ), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です" + args.cache_text_encoder_outputs = True + args.cache_text_encoder_outputs_to_disk = True use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - # tokenizerを準備する:datasetを動かすために必要 - if args.sdxl: - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - tokenizers = [tokenizer1, tokenizer2] - else: - tokenizer = train_util.load_tokenizer(args) - tokenizers = [tokenizer] + is_sd = not args.sdxl and not args.flux + is_sdxl = args.sdxl + is_flux = args.flux + + assert ( + is_sdxl or is_flux + ), "Cache text encoder outputs to disk is only supported for SDXL and FLUX models / テキストエンコーダ出力のディスクキャッシュはSDXLまたはFLUXでのみ有効です" + assert ( + is_sdxl or args.weighted_captions is None + ), "Weighted captions are only supported for SDXL models / 重み付きキャプションはSDXLモデルでのみ有効です" + + set_tokenize_strategy(is_sd, is_sdxl, is_flux, args) # データセットを準備する + use_user_config = args.dataset_config is not None if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) - if args.dataset_config is not None: - logger.info(f"Load dataset config from {args.dataset_config}") + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if use_user_config: + logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] + ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) @@ -87,100 +102,117 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + # use arbitrary dataset class + train_dataset_group = train_util.load_arbitrary_dataset(args) # acceleratorを準備する logger.info("prepare accelerator") + args.deepspeed = False accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, _ = train_util.prepare_dtype(args) + t5xxl_dtype = utils.str_to_dtype(args.t5xxl_dtype, weight_dtype) # モデルを読み込む logger.info("load model") - if args.sdxl: - (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) + if is_sdxl: + _, text_encoder1, text_encoder2, _, _, _, _ = sdxl_train_util.load_target_model( + args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype + ) + text_encoder1.to(accelerator.device, weight_dtype) + text_encoder2.to(accelerator.device, weight_dtype) text_encoders = [text_encoder1, text_encoder2] else: - text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) - text_encoders = [text_encoder1] + clip_l = flux_utils.load_clip_l( + args.clip_l, weight_dtype, accelerator.device, disable_mmap=args.disable_mmap_load_safetensors + ) + + t5xxl = flux_utils.load_t5xxl(args.t5xxl, None, accelerator.device, disable_mmap=args.disable_mmap_load_safetensors) + + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + if t5xxl_dtype != t5xxl_dtype: + if t5xxl.dtype == torch.float8_e4m3fn and t5xxl_dtype.itemsize() >= 2: + logger.warning( + "The loaded model is fp8, but the specified T5XXL dtype is larger than fp8. This may cause a performance drop." + " / ロードされたモデルはfp8ですが、指定されたT5XXLのdtypeがfp8より高精度です。精度低下が発生する可能性があります。" + ) + logger.info(f"Casting T5XXL model to {t5xxl_dtype}") + t5xxl.to(t5xxl_dtype) + + text_encoders = [clip_l, t5xxl] for text_encoder in text_encoders: - text_encoder.to(accelerator.device, dtype=weight_dtype) text_encoder.requires_grad_(False) text_encoder.eval() - # dataloaderを準備する - train_dataset_group.set_caching_mode("text") + # build text encoder outputs caching strategy + if is_sdxl: + text_encoder_outputs_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions + ) + else: + text_encoder_outputs_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=False, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + # build text encoding strategy + if is_sdxl: + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + else: + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず - train_dataloader = accelerator.prepare(train_dataloader) - - # データ取得のためのループ - for batch in tqdm(train_dataloader): - absolute_paths = batch["absolute_paths"] - input_ids1_list = batch["input_ids1_list"] - input_ids2_list = batch["input_ids2_list"] - - image_infos = [] - for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list): - image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) - image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX - image_info - - if args.skip_existing: - if os.path.exists(image_info.text_encoder_outputs_npz): - logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") - continue - - image_info.input_ids1 = input_ids1 - image_info.input_ids2 = input_ids2 - image_infos.append(image_info) - - if len(image_infos) > 0: - b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos]) - b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos]) - train_util.cache_batch_text_encoder_outputs( - image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype - ) + # cache text encoder outputs + train_dataset_group.new_cache_text_encoder_outputs(text_encoders, accelerator) accelerator.wait_for_everyone() - accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + accelerator.print(f"Finished caching text encoder outputs to disk.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_training_arguments(parser, True) train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_masked_loss_arguments(parser) config_util.add_config_arguments(parser) - sdxl_train_util.add_sdxl_training_arguments(parser) + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") + parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する") + parser.add_argument( + "--t5xxl_dtype", + type=str, + default=None, + help="T5XXL model dtype, default: None (use mixed precision dtype) / T5XXLモデルのdtype, デフォルト: None (mixed precisionのdtypeを使用)", + ) parser.add_argument( "--skip_existing", action="store_true", - help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check." + " / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。", + ) + parser.add_argument( + "--weighted_captions", + action="store_true", + default=False, + help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意", ) return parser diff --git a/tools/convert_diffusers_to_flux.py b/tools/convert_diffusers_to_flux.py new file mode 100644 index 00000000..65ba7321 --- /dev/null +++ b/tools/convert_diffusers_to_flux.py @@ -0,0 +1,149 @@ +# This script converts the diffusers of a Flux model to a safetensors file of a Flux.1 model. +# It is based on the implementation by 2kpr. Thanks to 2kpr! +# Major changes: +# - Iterates over three safetensors files to reduce memory usage, not loading all tensors at once. +# - Makes reverse map from diffusers map to avoid loading all tensors. +# - Removes dependency on .json file for weights mapping. +# - Adds support for custom memory efficient load and save functions. +# - Supports saving with different precision. +# - Supports .safetensors file as input. + +# Copyright 2024 2kpr. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import os +from pathlib import Path +import safetensors +from safetensors.torch import safe_open +import torch +from tqdm import tqdm + +from library import flux_utils +from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def convert(args): + # if diffusers_path is folder, get safetensors file + diffusers_path = Path(args.diffusers_path) + if diffusers_path.is_dir(): + diffusers_path = Path.joinpath(diffusers_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors") + + flux_path = Path(args.save_to) + if not os.path.exists(flux_path.parent): + os.makedirs(flux_path.parent) + + if not diffusers_path.exists(): + logger.error(f"Error: Missing transformer safetensors file: {diffusers_path}") + return + + mem_eff_flag = args.mem_eff_load_save + save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None + + # make reverse map from diffusers map + diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map() + + # iterate over three safetensors files to reduce memory usage + flux_sd = {} + for i in range(3): + # replace 00001 with 0000i + current_diffusers_path = Path(str(diffusers_path).replace("00001", f"0000{i+1}")) + logger.info(f"Loading diffusers file: {current_diffusers_path}") + + open_func = MemoryEfficientSafeOpen if mem_eff_flag else (lambda x: safe_open(x, framework="pt")) + with open_func(current_diffusers_path) as f: + for diffusers_key in tqdm(f.keys()): + if diffusers_key in diffusers_to_bfl_map: + tensor = f.get_tensor(diffusers_key).to("cpu") + if save_dtype is not None: + tensor = tensor.to(save_dtype) + + index, bfl_key = diffusers_to_bfl_map[diffusers_key] + if bfl_key not in flux_sd: + flux_sd[bfl_key] = [] + flux_sd[bfl_key].append((index, tensor)) + else: + logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}") + return + + # concat tensors if multiple tensors are mapped to a single key, sort by index + for key, values in flux_sd.items(): + if len(values) == 1: + flux_sd[key] = values[0][1] + else: + flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])]) + + # special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + if "final_layer.adaLN_modulation.1.weight" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"]) + if "final_layer.adaLN_modulation.1.bias" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"]) + + # save flux_sd to safetensors file + logger.info(f"Saving Flux safetensors file: {flux_path}") + if mem_eff_flag: + mem_eff_save_file(flux_sd, flux_path) + else: + safetensors.torch.save_file(flux_sd, flux_path) + + logger.info("Conversion completed.") + + +def setup_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--diffusers_path", + default=None, + type=str, + required=True, + help="Path to the original Flux diffusers folder or *-00001-of-00003.safetensors file." + " / 元のFlux diffusersフォルダーまたは*-00001-of-00003.safetensorsファイルへのパス", + ) + parser.add_argument( + "--save_to", + default=None, + type=str, + required=True, + help="Output path for the Flux safetensors file. / Flux safetensorsファイルの出力先", + ) + parser.add_argument( + "--mem_eff_load_save", + action="store_true", + help="use custom memory efficient load and save functions for FLUX.1 model" + " / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する", + ) + parser.add_argument( + "--save_precision", + type=str, + default=None, + help="precision in saving, default is same as loading precision" + "float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz" + " / 保存時に精度を変更して保存する、デフォルトは読み込み時と同じ精度", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + args = parser.parse_args() + convert(args) diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py index bbc643ed..d2a4d9cf 100644 --- a/tools/detect_face_rotate.py +++ b/tools/detect_face_rotate.py @@ -15,7 +15,7 @@ import os from anime_face_detector import create_detector from tqdm import tqdm import numpy as np -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging logger = logging.getLogger(__name__) @@ -172,7 +172,10 @@ def process(args): if scale != 1.0: w = int(w * scale + .5) h = int(h * scale + .5) - face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4) + if scale < 1.0: + face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA) + else: + face_img = pil_resize(face_img, (w, h)) cx = int(cx * scale + .5) cy = int(cy * scale + .5) fw = int(fw * scale + .5) diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py index b8069fc1..0f9e00b1 100644 --- a/tools/resize_images_to_resolution.py +++ b/tools/resize_images_to_resolution.py @@ -6,7 +6,7 @@ import shutil import math from PIL import Image import numpy as np -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging logger = logging.getLogger(__name__) @@ -24,9 +24,9 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi # Select interpolation method if interpolation == 'lanczos4': - cv2_interpolation = cv2.INTER_LANCZOS4 + pil_interpolation = Image.LANCZOS elif interpolation == 'cubic': - cv2_interpolation = cv2.INTER_CUBIC + pil_interpolation = Image.BICUBIC else: cv2_interpolation = cv2.INTER_AREA @@ -64,7 +64,10 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi new_width = int(img.shape[1] * math.sqrt(scale_factor)) # Resize image - img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) + if cv2_interpolation: + img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) + else: + img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation) else: new_height, new_width = img.shape[0:2] diff --git a/train_control_net.py b/train_control_net.py new file mode 100644 index 00000000..177d2b11 --- /dev/null +++ b/train_control_net.py @@ -0,0 +1,669 @@ +import argparse +import json +import math +import os +import random +import time +from multiprocessing import Value + +# from omegaconf import OmegaConf +import toml + +from tqdm import tqdm + +import torch +from library import deepspeed_utils +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from torch.nn.parallel import DistributedDataParallel as DDP +from accelerate.utils import set_seed +from diffusers import DDPMScheduler, ControlNetModel +from safetensors.torch import load_file + +import library.model_util as model_util +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.huggingface_util as huggingface_util +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + apply_snr_weight, + pyramid_noise_like, + apply_noise_offset, +) +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +# TODO 他のスクリプトと共通化する +def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): + logs = { + "loss/current": current_loss, + "loss/average": avr_loss, + "lr": lr_scheduler.get_last_lr()[0], + } + + if args.optimizer_type.lower().startswith("DAdapt".lower()): + logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + + return logs + + +def train(args): + # session_id = random.randint(0, 2**32) + # training_started_at = time.time() + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + setup_logging(args, reset=True) + + cache_latents = args.cache_latents + use_user_config = args.dataset_config is not None + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + tokenizer = train_util.load_tokenizer(args) + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) + if use_user_config: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "conditioning_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension, + ) + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(64) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + text_encoder, vae, unet, _ = train_util.load_target_model( + args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True + ) + + # DiffusersのControlNetが使用するデータを準備する + if args.v2: + unet.config = { + "act_fn": "silu", + "attention_head_dim": [5, 10, 20, 20], + "block_out_channels": [320, 640, 1280, 1280], + "center_input_sample": False, + "cross_attention_dim": 1024, + "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], + "downsample_padding": 1, + "dual_cross_attention": False, + "flip_sin_to_cos": True, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_attention_heads": [5, 10, 20, 20], + "num_class_embeds": None, + "only_cross_attention": False, + "out_channels": 4, + "sample_size": 96, + "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], + "use_linear_projection": True, + "upcast_attention": True, + "only_cross_attention": False, + "downsample_padding": 1, + "use_linear_projection": True, + "class_embed_type": None, + "num_class_embeds": None, + "resnet_time_scale_shift": "default", + "projection_class_embeddings_input_dim": None, + } + else: + unet.config = { + "act_fn": "silu", + "attention_head_dim": 8, + "block_out_channels": [320, 640, 1280, 1280], + "center_input_sample": False, + "cross_attention_dim": 768, + "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], + "downsample_padding": 1, + "flip_sin_to_cos": True, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_attention_heads": 8, + "out_channels": 4, + "sample_size": 64, + "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], + "only_cross_attention": False, + "downsample_padding": 1, + "use_linear_projection": False, + "class_embed_type": None, + "num_class_embeds": None, + "upcast_attention": False, + "resnet_time_scale_shift": "default", + "projection_class_embeddings_input_dim": None, + } + # unet.config = OmegaConf.create(unet.config) + + # make unet.config iterable and accessible by attribute + class CustomConfig: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + def __getattr__(self, name): + if name in self.__dict__: + return self.__dict__[name] + else: + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + def __contains__(self, name): + return name in self.__dict__ + + unet.config = CustomConfig(**unet.config) + + controlnet = ControlNetModel.from_unet(unet) + + if args.controlnet_model_name_or_path: + filename = args.controlnet_model_name_or_path + if os.path.isfile(filename): + if os.path.splitext(filename)[1] == ".safetensors": + state_dict = load_file(filename) + else: + state_dict = torch.load(filename) + state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict) + controlnet.load_state_dict(state_dict) + elif os.path.isdir(filename): + controlnet = ControlNetModel.from_pretrained(filename) + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents( + vae, + args.vae_batch_size, + args.cache_latents_to_disk, + accelerator.is_main_process, + ) + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + controlnet.enable_gradient_checkpointing() + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + trainable_params = list(controlnet.parameters()) + + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + controlnet.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) + + if args.fused_backward_pass: + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + unet.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.to(accelerator.device) + text_encoder.to(accelerator.device) + + # transform DDP after prepare + controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet + + controlnet.train() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm( + range(args.max_train_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="steps", + ) + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + ) + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + loss_recorder = train_util.LossRecorder() + del train_dataset_group + + # function for saving/removing + def save_model(ckpt_name, model, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + + state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) + + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + if os.path.splitext(ckpt_file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, ckpt_file) + else: + torch.save(state_dict, ckpt_file) + + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # For --sample_at_first + train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet + ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) + + # training loop + for epoch in range(num_train_epochs): + if is_main_process: + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(controlnet): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) + elif args.multires_noise_iterations: + noise = pyramid_noise_like( + noise, + latents.device, + args.multires_noise_iterations, + args.multires_noise_discount, + ) + + # Sample a random timestep for each image + timesteps = train_util.get_timesteps(0, noise_scheduler.config.num_train_timesteps, b_size, latents.device) + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) + + with accelerator.autocast(): + down_block_res_samples, mid_block_res_sample = controlnet( + noisy_latents, + timesteps, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=controlnet_image, + return_dict=False, + ) + + # Predict the noise residual + noise_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states, + down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), + ).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if not args.fused_backward_pass: + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + tokenizer, + text_encoder, + unet, + controlnet=controlnet, + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model( + ckpt_name, + accelerator.unwrap_model(controlnet), + ) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + + current_loss = loss.detach().item() + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if len(accelerator.trackers) > 0: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + # 指定エポックごとにモデルを保存 + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name, accelerator.unwrap_model(controlnet)) + + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) + + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + + train_util.sample_images( + accelerator, + args, + epoch + 1, + global_step, + accelerator.device, + vae, + tokenizer, + text_encoder, + unet, + controlnet=controlnet, + ) + + # end of epoch + if is_main_process: + controlnet = accelerator.unwrap_model(controlnet) + + accelerator.end_training() + + if is_main_process and (args.save_state or args.save_state_on_train_end): + train_util.save_state_on_train_end(args, accelerator) + + # del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく + + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) + save_model(ckpt_name, controlnet, force_sync_upload=True) + + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="controlnet model name or path / controlnetのモデル名またはパス", + ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/train_controlnet.py b/train_controlnet.py index e55814ea..09a911a0 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -1,40 +1,4 @@ -import argparse -import json -import math -import os -import random -import time -from multiprocessing import Value -from types import SimpleNamespace -import toml - -from tqdm import tqdm - -import torch -from library import deepspeed_utils -from library.device_utils import init_ipex, clean_memory_on_device -init_ipex() - -from torch.nn.parallel import DistributedDataParallel as DDP -from accelerate.utils import set_seed -from diffusers import DDPMScheduler, ControlNetModel -from safetensors.torch import load_file - -import library.model_util as model_util -import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, -) -import library.huggingface_util as huggingface_util -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import ( - apply_snr_weight, - pyramid_noise_like, - apply_noise_offset, -) -from library.utils import setup_logging, add_logging_arguments +from library.utils import setup_logging setup_logging() import logging @@ -42,6 +6,7 @@ import logging logger = logging.getLogger(__name__) +<<<<<<< HEAD # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): logs = { @@ -608,8 +573,16 @@ def setup_parser() -> argparse.ArgumentParser: return parser +======= +from library import train_util +from train_control_net import setup_parser, train +>>>>>>> hina/feature/val-loss if __name__ == "__main__": + logger.warning( + "The module 'train_controlnet.py' is deprecated. Please use 'train_control_net.py' instead" + " / 'train_controlnet.py'は非推奨です。代わりに'train_control_net.py'を使用してください。" + ) parser = setup_parser() args = parser.parse_args() diff --git a/train_db.py b/train_db.py index 2c17e521..398489ff 100644 --- a/train_db.py +++ b/train_db.py @@ -10,7 +10,7 @@ import toml from tqdm import tqdm import torch -from library import deepspeed_utils +from library import deepspeed_utils, strategy_base from library.device_utils import init_ipex, clean_memory_on_device @@ -37,6 +37,7 @@ from library.custom_train_functions import ( apply_masked_loss, ) from library.utils import setup_logging, add_logging_arguments +import library.strategy_sd as strategy_sd setup_logging() import logging @@ -119,7 +120,14 @@ def train(args): if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する - tokenizer = train_util.load_tokenizer(args) + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -141,10 +149,10 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + train_dataset_group = train_util.load_arbitrary_dataset(args) val_dataset_group = None current_epoch = Value("i", 0) @@ -155,6 +163,8 @@ def train(args): if args.no_token_padding: train_dataset_group.disable_token_padding() + train_dataset_group.verify_bucket_reso_steps(64) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return @@ -207,16 +217,17 @@ def train(args): vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - if val_dataset_group is not None: - print("Cache validation latents...") - val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator) + vae.to("cpu") clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # 学習を準備する:モデルを適切な状態にする train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 unet.requires_grad_(True) # 念のため追加 @@ -249,8 +260,11 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, @@ -364,10 +378,19 @@ def train(args): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + accelerator.init_trackers( + "dreambooth" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) # For --sample_at_first - train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() val_loss_recorder = train_util.LossRecorder() @@ -406,23 +429,21 @@ def train(args): # Get the text embedding for conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, [text_encoder], input_ids_list, weights_list + )[0] else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) + input_ids = batch["input_ids_list"][0].to(accelerator.device) + encoder_hidden_states = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder], [input_ids] + )[0] + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -434,8 +455,9 @@ def train(args): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -447,7 +469,7 @@ def train(args): if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし @@ -469,7 +491,7 @@ def train(args): global_step += 1 train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet ) # 指定ステップごとにモデルを保存 @@ -494,7 +516,7 @@ def train(args): ) current_loss = loss.detach().item() - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) accelerator.log(logs, step=global_step) @@ -526,8 +548,8 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: - logs = {"loss/epoch_average": loss_recorder.moving_average} + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone() @@ -552,7 +574,9 @@ def train(args): vae, ) - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) is_main_process = accelerator.is_main_process if is_main_process: diff --git a/train_network.py b/train_network.py index 7c1d8ce2..99b9717a 100644 --- a/train_network.py +++ b/train_network.py @@ -9,7 +9,9 @@ import random import time import json from multiprocessing import Value +from typing import Any, List import toml +import itertools from tqdm import tqdm @@ -20,9 +22,10 @@ from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed +from accelerate import Accelerator from diffusers import DDPMScheduler, AutoencoderKL from diffusers.models.modeling_outputs import AutoencoderKLOutput -from library import deepspeed_utils, model_util +from library import deepspeed_utils, model_util, strategy_base, strategy_sd import library.train_util as train_util from library.train_util import DreamBoothDataset @@ -58,7 +61,16 @@ class NetworkTrainer: # TODO 他のスクリプトと共通化する def generate_step_logs( - self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None + self, + args: argparse.Namespace, + current_loss, + avr_loss, + lr_scheduler, + lr_descriptions, + optimizer=None, + keys_scaled=None, + mean_norm=None, + maximum_norm=None, ): logs = {"loss/current": current_loss, "loss/average": avr_loss} @@ -68,21 +80,31 @@ class NetworkTrainer: logs["max_norm/max_key_norm"] = maximum_norm lrs = lr_scheduler.get_last_lr() - - if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block) - if args.network_train_unet_only: - logs["lr/unet"] = float(lrs[0]) - elif args.network_train_text_encoder_only: - logs["lr/textencoder"] = float(lrs[0]) + for i, lr in enumerate(lrs): + if lr_descriptions is not None: + lr_desc = lr_descriptions[i] else: - logs["lr/textencoder"] = float(lrs[0]) - logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder + idx = i - (0 if args.network_train_unet_only else -1) + if idx == -1: + lr_desc = "textencoder" + else: + if len(lrs) > 2: + lr_desc = f"group{idx}" + else: + lr_desc = "unet" + logs[f"lr/{lr_desc}"] = lr + + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): + # tracking d*lr value + logs[f"lr/d*lr/{lr_desc}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] + ) if ( - args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() + args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None ): # tracking d*lr value of unet. logs["lr/d*lr"] = ( - lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] ) else: idx = 0 @@ -96,39 +118,69 @@ class NetworkTrainer: logs[f"lr/d*lr/group{i}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) + if ( + args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None + ): + logs[f"lr/d*lr/group{i}"] = ( + optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] + ) return logs def assert_extra_args(self, args, train_dataset_group): + # train_dataset_group.verify_bucket_reso_steps(64) + # TODO: Number of bucket reso steps may differ for each model, so a static number won't work + # and prevents models like SD1.5 with 64 pass def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) + return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet - def load_tokenizer(self, args): - tokenizer = train_util.load_tokenizer(args) - return tokenizer + def get_tokenize_strategy(self, args): + return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) - def is_text_encoder_outputs_cached(self, args): - return False + def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]: + return [tokenize_strategy.tokenizer] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sd.SdTextEncodingStrategy(args.clip_skip) + + def get_text_encoder_outputs_caching_strategy(self, args): + return None + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + """ + Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models. + FLUX.1 and SD3 may cache some outputs of the text encoder, so return the models that will be used for encoding (not cached). + """ + return text_encoders + + # returns a list of bool values indicating whether each text encoder should be trained + def get_text_encoders_train_flags(self, args, text_encoders): + return [True] * len(text_encoders) if self.is_train_text_encoder(args) else [False] * len(text_encoders) def is_train_text_encoder(self, args): - return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args) + return not args.network_train_unet_only - def cache_text_encoder_outputs_if_needed( - self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype - ): + def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype): for t_enc in text_encoders: t_enc.to(accelerator.device, dtype=weight_dtype) - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], weight_dtype) - return encoder_hidden_states - - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - noise_pred = unet(noisy_latents, timesteps, text_conds).sample + def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs): + noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred def all_reduce_network(self, accelerator, network): @@ -136,10 +188,136 @@ class NetworkTrainer: if param.grad is not None: param.grad = accelerator.reduce(param.grad, reduction="mean") - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): - train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet): + train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet) - def process_batch(self, batch, tokenizers, text_encoders, unet, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None): + # region SD/SDXL + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + pass + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images).latent_dist.sample() + + def shift_scale_latents(self, args, latents): + return latents * self.vae_scale_factor + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + ): + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + for x in noisy_latents: + x.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = self.call_unet( + args, + accelerator, + unet, + noisy_latents.requires_grad_(train_unet), + timesteps, + text_encoder_conds, + batch, + weight_dtype, + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + noise_pred_prior = self.call_unet( + args, + accelerator, + unet, + noisy_latents, + timesteps, + text_encoder_conds, + batch, + weight_dtype, + indices=diff_output_pr_indices, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) + + return noise_pred, target, timesteps, None + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False) + + def update_metadata(self, metadata, args): + pass + + def is_text_encoder_not_needed_for_training(self, args): + return False # use for sample images + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + # set top parameter requires_grad = True for gradient checkpointing works + text_encoder.text_model.embeddings.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + return accelerator.prepare(unet) + + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + pass + + # endregion + + def process_batch(self, batch, tokenizers, text_encoders, unet, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor: with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: @@ -154,21 +332,42 @@ class NetworkTrainer: latents = typing.cast(torch.FloatTensor, torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)) latents = typing.cast(torch.FloatTensor, latents * self.vae_scale_factor) - with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): - # Get the text embedding for conditioning - if args.weighted_captions: - text_encoder_conds = get_weighted_text_embeddings( - tokenizers[0], - text_encoders[0], - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) + + text_encoder_conds = [] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs + + if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: + # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' + with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): + # Get the text embedding for conditioning + if args.weighted_captions: + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids_list, + weights_list, + ) + else: + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids, + ) + if args.full_fp16: + encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] + + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds else: - text_encoder_conds = self.get_text_cond( - args, accelerator, batch, tokenizers, text_encoders, weight_dtype - ) + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] batch_size = latents.shape[0] @@ -204,9 +403,36 @@ class NetworkTrainer: else: target = noise + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + noise_pred_prior = self.call_unet( + args, + accelerator, + unet, + noisy_latents, + timesteps, + text_encoder_conds, + batch, + weight_dtype, + indices=diff_output_pr_indices, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) # 平均なのでbatch_sizeで割る必要なし + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight loss = loss * loss_weights @@ -221,7 +447,7 @@ class NetworkTrainer: total_loss += loss - return total_loss.mean() + return total_loss / len(choosen_timesteps_list) def train(self, args): session_id = random.randint(0, 2**32) @@ -239,9 +465,13 @@ class NetworkTrainer: args.seed = random.randint(0, 2**32) set_seed(args.seed) - # tokenizerは単体またはリスト、tokenizersは必ずリスト:既存のコードとの互換性のため - tokenizer = self.load_tokenizer(args) - tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer] + tokenize_strategy = self.get_tokenize_strategy(args) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = self.get_latents_caching_strategy(args) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する if args.dataset_class is None: @@ -283,11 +513,11 @@ class NetworkTrainer: ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + train_dataset_group = train_util.load_arbitrary_dataset(args) val_dataset_group = None # placeholder until validation dataset supported for arbitrary current_epoch = Value("i", 0) @@ -296,6 +526,7 @@ class NetworkTrainer: collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) if args.debug_dataset: + train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: @@ -311,9 +542,9 @@ class NetworkTrainer: if val_dataset_group is not None: assert ( val_dataset_group.is_latent_cacheable() - ), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - self.assert_extra_args(args, train_dataset_group) + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + self.assert_extra_args(args, train_dataset_group) # may change some args # acceleratorを準備する logger.info("preparing accelerator") @@ -330,11 +561,6 @@ class NetworkTrainer: # text_encoder is List[CLIPTextModel] or CLIPTextModel text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える - vae.set_use_memory_efficient_attention_xformers(args.xformers) - # 差分追加学習のためにモデルを読み込む sys.path.append(os.path.dirname(__file__)) accelerator.print("import network module:", args.network_module) @@ -362,11 +588,11 @@ class NetworkTrainer: vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - if val_dataset_group is not None: - print("Cache validation latents...") - val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + + train_dataset_group.new_cache_latents(vae, accelerator) + if val_dataset_group is not None: + val_dataset_group.new_cache_latents(vae, accelerator) + vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -374,9 +600,15 @@ class NetworkTrainer: # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu - self.cache_text_encoder_outputs_if_needed( - args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype - ) + text_encoding_strategy = self.get_text_encoding_strategy(args) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + text_encoder_outputs_caching_strategy = self.get_text_encoder_outputs_caching_strategy(args) + if text_encoder_outputs_caching_strategy is not None: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) + self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype) + if val_dataset_group is not None: + self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype) # prepare network net_kwargs = {} @@ -415,36 +647,79 @@ class NetworkTrainer: ) args.scale_weight_norms = False + self.post_process_network(args, accelerator, network, text_encoders, unet) + + # apply network to unet and text_encoder train_unet = not args.network_train_text_encoder_only train_text_encoder = self.is_train_text_encoder(args) network.apply_to(text_encoder, unet, train_text_encoder, train_unet) if args.network_weights is not None: + # FIXME consider alpha of weights: this assumes that the alpha is not changed info = network.load_weights(args.network_weights) accelerator.print(f"load network weights from {args.network_weights}: {info}") if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - for t_enc in text_encoders: - t_enc.gradient_checkpointing_enable() + if args.cpu_offload_checkpointing: + unet.enable_gradient_checkpointing(cpu_offload=True) + else: + unet.enable_gradient_checkpointing() + + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): + if flag: + if t_enc.supports_gradient_checkpointing: + t_enc.gradient_checkpointing_enable() del t_enc network.enable_gradient_checkpointing() # may have no effect # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - # 後方互換性を確保するよ + # make backward compatibility for text_encoder_lr + support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs") + if support_multiple_lrs: + text_encoder_lr = args.text_encoder_lr + else: + # toml backward compatibility + if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float) or isinstance(args.text_encoder_lr, int): + text_encoder_lr = args.text_encoder_lr + else: + text_encoder_lr = None if len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] try: - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) - except TypeError: - accelerator.print( - "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" - ) - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + if support_multiple_lrs: + results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate) + else: + results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate) + if type(results) is tuple: + trainable_params = results[0] + lr_descriptions = results[1] + else: + trainable_params = results + lr_descriptions = None + except TypeError as e: + trainable_params = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr) + lr_descriptions = None + + # if len(trainable_params) == 0: + # accelerator.print("no trainable parameters found / 学習可能なパラメータが見つかりませんでした") + # for params in trainable_params: + # for k, v in params.items(): + # if type(v) == float: + # pass + # else: + # v = len(v) + # accelerator.print(f"trainable_params: {k} = {v}") optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + if val_dataset_group is not None: + val_dataset_group.set_current_strategies() - # dataloaderを準備する # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -478,6 +753,10 @@ class NetworkTrainer: # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) + # Not for sure here. + # if val_dataset_group is not None: + # val_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -497,33 +776,47 @@ class NetworkTrainer: unet_weight_dtype = te_weight_dtype = weight_dtype # Experimental Feature: Put base model into fp8 to save vram - if args.fp8_base: + if args.fp8_base or args.fp8_base_unet: assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" assert ( args.mixed_precision != "no" ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" - accelerator.print("enable fp8 training.") + accelerator.print("enable fp8 training for U-Net.") unet_weight_dtype = torch.float8_e4m3fn - te_weight_dtype = torch.float8_e4m3fn + + if not args.fp8_base_unet: + accelerator.print("enable fp8 training for Text Encoder.") + te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn + + # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM + # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory + + # logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}") + # unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above + logger.info(f"set U-Net weight dtype to {unet_weight_dtype}") + unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator unet.requires_grad_(False) unet.to(dtype=unet_weight_dtype) - for t_enc in text_encoders: + for i, t_enc in enumerate(text_encoders): t_enc.requires_grad_(False) # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 if t_enc.device.type != "cpu": t_enc.to(dtype=te_weight_dtype) + # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + if te_weight_dtype != weight_dtype: + self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: + flags = self.get_text_encoders_train_flags(args, text_encoders) ds_model = deepspeed_utils.prepare_deepspeed_model( args, unet=unet if train_unet else None, - text_encoder1=text_encoders[0] if train_text_encoder else None, - text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None, + text_encoder1=text_encoders[0] if flags[0] else None, + text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None, network=network, ) ds_model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( @@ -532,15 +825,19 @@ class NetworkTrainer: training_model = ds_model else: if train_unet: - unet = accelerator.prepare(unet) + # default implementation is: unet = accelerator.prepare(unet) + unet = self.prepare_unet_with_accelerator(args, accelerator, unet) # accelerator does some magic here else: unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator if train_text_encoder: + text_encoders = [ + (accelerator.prepare(t_enc) if flag else t_enc) + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)) + ] if len(text_encoders) > 1: - text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders] + text_encoder = text_encoders else: - text_encoder = accelerator.prepare(text_encoder) - text_encoders = [text_encoder] + text_encoder = text_encoders[0] else: pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set @@ -552,12 +849,12 @@ class NetworkTrainer: if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() - for t_enc in text_encoders: + for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))): t_enc.train() # set top parameter requires_grad = True for gradient checkpointing works - if train_text_encoder: - t_enc.text_model.embeddings.requires_grad_(True) + if frag: + self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc) else: unet.eval() @@ -580,15 +877,26 @@ class NetworkTrainer: # before resuming make hook for saving/loading to save/load the network weights only def save_model_hook(models, weights, output_dir): # pop weights of other models than network to save only network weights - if accelerator.is_main_process: + # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606 + if accelerator.is_main_process or args.deepspeed: remove_indices = [] for i, model in enumerate(models): if not isinstance(model, type(accelerator.unwrap_model(network))): remove_indices.append(i) for i in reversed(remove_indices): - weights.pop(i) + if len(weights) > i: + weights.pop(i) # print(f"save model hook: {len(weights)} weights will be saved") + # save current ecpoch and step + train_state_file = os.path.join(output_dir, "train_state.json") + # +1 is needed because the state is saved before current_step is set from global_step + logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}") + with open(train_state_file, "w", encoding="utf-8") as f: + json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f) + + steps_from_state = None + def load_model_hook(models, input_dir): # remove models except network remove_indices = [] @@ -599,6 +907,15 @@ class NetworkTrainer: models.pop(i) # print(f"load model hook: {len(models)} models will be loaded") + # load current epoch and step to + nonlocal steps_from_state + train_state_file = os.path.join(input_dir, "train_state.json") + if os.path.exists(train_state_file): + with open(train_state_file, "r", encoding="utf-8") as f: + data = json.load(f) + steps_from_state = data["current_step"] + logger.info(f"load train state from {train_state_file}: {data}") + accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -633,7 +950,7 @@ class NetworkTrainer: "ss_training_started_at": training_started_at, # unix timestamp "ss_output_name": args.output_name, "ss_learning_rate": args.learning_rate, - "ss_text_encoder_lr": args.text_encoder_lr, + "ss_text_encoder_lr": text_encoder_lr, "ss_unet_lr": args.unet_lr, "ss_num_train_images": train_dataset_group.num_train_images, "ss_num_reg_images": train_dataset_group.num_reg_images, @@ -679,9 +996,14 @@ class NetworkTrainer: "ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength, "ss_loss_type": args.loss_type, "ss_huber_schedule": args.huber_schedule, + "ss_huber_scale": args.huber_scale, "ss_huber_c": args.huber_c, + "ss_fp8_base": bool(args.fp8_base), + "ss_fp8_base_unet": bool(args.fp8_base_unet), } + self.update_metadata(metadata, args) # architecture specific metadata + if use_user_config: # save metadata of multiple datasets # NOTE: pack "ss_datasets" value as json one time @@ -842,15 +1164,57 @@ class NetworkTrainer: if key in metadata: minimum_metadata[key] = metadata[key] - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + # calculate steps to skip when resuming or starting from a specific step + initial_step = 0 + if args.initial_epoch is not None or args.initial_step is not None: + # if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming + if steps_from_state is not None: + logger.warning( + "steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます" + ) + if args.initial_step is not None: + initial_step = args.initial_step + else: + # num steps per epoch is calculated by num_processes and gradient_accumulation_steps + initial_step = (args.initial_epoch - 1) * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + else: + # if initial_epoch and initial_step are not specified, steps_from_state is used when resuming + if steps_from_state is not None: + initial_step = steps_from_state + steps_from_state = None + + if initial_step > 0: + assert ( + args.max_train_steps > initial_step + ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}" + + progress_bar = tqdm( + range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" + ) + + epoch_to_start = 0 + if initial_step > 0: + if args.skip_until_initial_step: + # if skip_until_initial_step is specified, load data and discard it to ensure the same data is used + if not args.resume: + logger.info( + f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります" + ) + logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします") + initial_step *= args.gradient_accumulation_steps + + # set epoch to start to make initial_step less than len(train_dataloader) + epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + else: + # if not, only epoch no is skipped for informative purpose + epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + initial_step = 0 # do not skip + global_step = 0 - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - if args.zero_terminal_snr: - custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + noise_scheduler = self.get_noise_scheduler(args, accelerator.device) if accelerator.is_main_process: init_kwargs = {} @@ -859,18 +1223,23 @@ class NetworkTrainer: if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "network_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() val_loss_recorder = train_util.LossRecorder() + del train_dataset_group + if val_dataset_group is not None: + del val_dataset_group # callback for step start if hasattr(accelerator.unwrap_model(network), "on_step_start"): - on_step_start = accelerator.unwrap_model(network).on_step_start + on_step_start_for_network = accelerator.unwrap_model(network).on_step_start else: - on_step_start = lambda *args, **kwargs: None + on_step_start_for_network = lambda *args, **kwargs: None # function for saving/removing def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): @@ -883,7 +1252,7 @@ class NetworkTrainer: metadata["ss_epoch"] = str(epoch_no) metadata_to_save = minimum_metadata if args.no_metadata else metadata - sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False) + sai_metadata = self.get_sai_model_spec(args) metadata_to_save.update(sai_metadata) unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save) @@ -896,11 +1265,45 @@ class NetworkTrainer: accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) + # if text_encoder is not needed for training, delete it to save memory. + # TODO this can be automated after SDXL sample prompt cache is implemented + if self.is_text_encoder_not_needed_for_training(args): + logger.info("text_encoder is not needed for training. deleting to save memory.") + for t_enc in text_encoders: + del t_enc + text_encoders = [] + text_encoder = None + # For --sample_at_first - self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + optimizer_eval_fn() + self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + optimizer_train_fn() + is_tracking = len(accelerator.trackers) > 0 + if is_tracking: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) + + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) # training loop - for epoch in range(num_train_epochs): + if initial_step > 0: # only if skip_until_initial_step is specified + for skip_epoch in range(epoch_to_start): # skip epochs + logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") + initial_step -= len(train_dataloader) + global_step = initial_step + + # log device and dtype for each model + logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") + for i, t_enc in enumerate(text_encoders): + params_itr = t_enc.parameters() + params_itr.__next__() # skip the first parameter + params_itr.__next__() # skip the second parameter. because CLIP first two parameters are embeddings + param_3rd = params_itr.__next__() + logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}") + + clean_memory_on_device(accelerator.device) + + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -909,12 +1312,19 @@ class NetworkTrainer: accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) # TRAINING + skipped_dataloader = None + if initial_step > 0: + skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1) + initial_step = 1 - for step, batch in enumerate(train_dataloader): + for step, batch in enumerate(skipped_dataloader or train_dataloader): current_step.value = global_step + if initial_step > 0: + initial_step -= 1 + continue + with accelerator.accumulate(training_model): - on_step_start(text_encoder, unet) - loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet) accelerator.backward(loss) if accelerator.sync_gradients: self.all_reduce_network(accelerator, network) # sync DDP grad manually @@ -939,7 +1349,10 @@ class NetworkTrainer: progress_bar.update(1) global_step += 1 - self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + optimizer_eval_fn() + self.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet + ) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -955,6 +1368,7 @@ class NetworkTrainer: if remove_step_no is not None: remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) remove_model(remove_ckpt_name) + optimizer_train_fn() current_loss = loss.detach().item() loss_recorder.add(epoch=epoch, step=step, loss=current_loss) @@ -965,72 +1379,85 @@ class NetworkTrainer: if args.scale_weight_norms: progress_bar.set_postfix(**{**max_mean_logs, **logs}) - if args.logging_dir is not None: - logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) - accelerator.log(logs, step=global_step) - - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - # batch = next(val_dataloader) - for val_step, batch in tqdm(enumerate(val_dataloader), desc='Validation Steps'): - if val_step >= validation_steps: - break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) - - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} + if is_tracking: + logs = self.generate_step_logs( + args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm + ) + accelerator.log(logs, step=global_step) + + # VALIDATION PER STEP + should_validate = (args.validation_every_n_step is not None + and global_step % args.validation_every_n_step == 0) + if validation_steps > 0 and should_validate: + accelerator.print("Validating バリデーション処理...") + + val_progress_bar = tqdm( + range(validation_steps), smoothing=0, + disable=not accelerator.is_local_main_process, + desc="validation steps" + ) + for val_step, batch in enumerate(val_dataloader): + if val_step >= validation_steps: + break + + loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) + + val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item()) + val_progress_bar.update(1) + val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) + + if is_tracking: + logs = {"loss/current_val_loss": loss.detach().item()} + accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + + logs = {"loss/average_val_loss": val_loss_recorder.moving_average} accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - # VALIDATION - + # VALIDATION EPOCH if len(val_dataloader) > 0: - print("Validating バリデーション処理...") + accelerator.print("Validating バリデーション処理...") + val_progress_bar = tqdm( + range(validation_steps), smoothing=0, + disable=not accelerator.is_local_main_process, + desc="validation steps" + ) - for val_step, batch in enumerate(val_dataloader): - loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) + for val_step, batch in enumerate(val_dataloader): + loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) - current_loss = loss.detach().item() - val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) + current_loss = loss.detach().item() + val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) + val_progress_bar.update(1) + val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) - if args.logging_dir is not None: - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation_current": current_loss} - accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + if is_tracking: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/validation_current": current_loss} + accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) - if len(val_dataloader) > 0: - if args.logging_dir is not None: + if is_tracking: avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation_average": avr_loss} accelerator.log(logs, step=epoch + 1) - - if args.logging_dir is not None: + # END OF EPOCH + if is_tracking: logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) - if len(val_dataloader) > 0: - if args.logging_dir is not None: - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation_epoch_average": avr_loss} - accelerator.log(logs, step=epoch + 1) + if len(val_dataloader) > 0 and is_tracking: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/validation_epoch_average": avr_loss} + accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 + optimizer_eval_fn() if args.save_every_n_epochs is not None: saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs if is_main_process and saving: @@ -1045,7 +1472,8 @@ class NetworkTrainer: if args.save_state: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + optimizer_train_fn() # end of epoch @@ -1056,6 +1484,7 @@ class NetworkTrainer: network = accelerator.unwrap_model(network) accelerator.end_training() + optimizer_eval_fn() if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) @@ -1080,6 +1509,12 @@ def setup_parser() -> argparse.ArgumentParser: config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported" + " / 勾配チェックポイント時にテンソルをCPUにオフロードする(U-NetまたはDiTのみ、サポートされている場合)", + ) parser.add_argument( "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない" ) @@ -1092,7 +1527,19 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") - parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") + parser.add_argument( + "--text_encoder_lr", + type=float, + default=None, + nargs="*", + help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能", + ) + parser.add_argument( + "--fp8_base_unet", + action="store_true", + help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16" + " / U-Net(またはDiT)にfp8を使用する。Text Encoderはfp16またはbf16", + ) parser.add_argument( "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" @@ -1169,30 +1616,49 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + parser.add_argument( + "--skip_until_initial_step", + action="store_true", + help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする", + ) + parser.add_argument( + "--initial_epoch", + type=int, + default=None, + help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`." + + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる", + ) + parser.add_argument( + "--initial_step", + type=int, + default=None, + help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch." + + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする", + ) parser.add_argument( "--validation_seed", type=int, default=None, - help="Validation seed" + help="Validation seed / 検証シード" ) parser.add_argument( "--validation_split", type=float, default=0.0, - help="Split for validation images out of the training dataset" - ) + help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" + ) parser.add_argument( "--validation_every_n_step", type=int, default=None, - help="Number of train steps for counting validation loss. By default, validation per train epoch is performed" - ) + help="Number of train steps for counting validation loss. By default, validation per train epoch is performed / 学習エポックごとに検証を行う場合はNoneを指定する" + ) parser.add_argument( "--max_validation_steps", type=int, default=None, - help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset" - ) + help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset / 検証データセット全体を検証する場合はNoneを指定する" + ) return parser diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 10fce267..65da4859 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -2,6 +2,7 @@ import argparse import math import os from multiprocessing import Value +from typing import Any, List import toml from tqdm import tqdm @@ -15,7 +16,7 @@ init_ipex() from accelerate.utils import set_seed from diffusers import DDPMScheduler from transformers import CLIPTokenizer -from library import deepspeed_utils, model_util +from library import deepspeed_utils, model_util, strategy_base, strategy_sd import library.train_util as train_util import library.huggingface_util as huggingface_util @@ -99,32 +100,42 @@ class TextualInversionTrainer: self.is_sdxl = False def assert_extra_args(self, args, train_dataset_group): - pass + train_dataset_group.verify_bucket_reso_steps(64) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) - return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet + return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet - def load_tokenizer(self, args): - tokenizer = train_util.load_tokenizer(args) - return tokenizer + def get_tokenize_strategy(self, args): + return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]: + return [tokenize_strategy.tokenizer] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy def assert_token_string(self, token_string, tokenizers: CLIPTokenizer): pass - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - with torch.enable_grad(): - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None) - return encoder_hidden_states + def get_text_encoding_strategy(self, args): + return strategy_sd.SdTextEncodingStrategy(args.clip_skip) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders) -> List[Any]: + return text_encoders def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - noise_pred = unet(noisy_latents, timesteps, text_conds).sample + noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): + def sample_images( + self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement + ): train_util.sample_images( - accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement + accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoders[0], unet, prompt_replacement ) def save_weights(self, file, updated_embs, save_dtype, metadata): @@ -182,8 +193,13 @@ class TextualInversionTrainer: if args.seed is not None: set_seed(args.seed) - tokenizer_or_list = self.load_tokenizer(args) # list of tokenizer or tokenizer - tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list] + tokenize_strategy = self.get_tokenize_strategy(args) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = self.get_latents_caching_strategy(args) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # acceleratorを準備する logger.info("prepare accelerator") @@ -194,14 +210,7 @@ class TextualInversionTrainer: vae_dtype = torch.float32 if args.no_half_vae else weight_dtype # モデルを読み込む - model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator) - text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list - - if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1: - accelerator.print( - "accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / " - + "accelerateでは複数のモデル(テキストエンコーダー)のgradient_accumulation_stepsはサポートされていないようです" - ) + model_version, text_encoders, vae, unet = self.load_target_model(args, weight_dtype, accelerator) # Convert the init_word to token_id init_token_ids_list = [] @@ -310,10 +319,10 @@ class TextualInversionTrainer: ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list) + train_dataset_group = train_util.load_arbitrary_dataset(args) self.assert_extra_args(args, train_dataset_group) @@ -368,11 +377,10 @@ class TextualInversionTrainer: vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - clean_memory_on_device(accelerator.device) + train_dataset_group.new_cache_latents(vae, accelerator) + + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() if args.gradient_checkpointing: @@ -387,7 +395,11 @@ class TextualInversionTrainer: trainable_params += text_encoder.get_input_embeddings().parameters() _, _, optimizer = train_util.get_optimizer(args, trainable_params) - # dataloaderを準備する + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -415,20 +427,8 @@ class TextualInversionTrainer: lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # acceleratorがなんかよろしくやってくれるらしい - if len(text_encoders) == 1: - text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder_or_list, optimizer, train_dataloader, lr_scheduler - ) - - elif len(text_encoders) == 2: - text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler - ) - - text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2] - - else: - raise NotImplementedError() + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + text_encoders = [accelerator.prepare(text_encoder) for text_encoder in text_encoders] index_no_updates_list = [] orig_embeds_params_list = [] @@ -456,6 +456,9 @@ class TextualInversionTrainer: else: unet.eval() + text_encoding_strategy = self.get_text_encoding_strategy(args) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する vae.requires_grad_(False) vae.eval() @@ -510,7 +513,9 @@ class TextualInversionTrainer: if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) # function for saving/removing @@ -540,11 +545,14 @@ class TextualInversionTrainer: global_step, accelerator.device, vae, - tokenizer_or_list, - text_encoder_or_list, + tokenizers, + text_encoders, unet, prompt_replacement, ) + if len(accelerator.trackers) > 0: + # log empty object to commit the sample images to wandb + accelerator.log({}, step=0) # training loop for epoch in range(num_train_epochs): @@ -568,11 +576,16 @@ class TextualInversionTrainer: latents = latents * self.vae_scale_factor # Get the text embedding for conditioning - text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype) + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents ) @@ -588,8 +601,9 @@ class TextualInversionTrainer: else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -603,7 +617,7 @@ class TextualInversionTrainer: if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし @@ -639,8 +653,8 @@ class TextualInversionTrainer: global_step, accelerator.device, vae, - tokenizer_or_list, - text_encoder_or_list, + tokenizers, + text_encoders, unet, prompt_replacement, ) @@ -672,7 +686,7 @@ class TextualInversionTrainer: remove_model(remove_ckpt_name) current_loss = loss.detach().item() - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} if ( args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() @@ -690,7 +704,7 @@ class TextualInversionTrainer: if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_total / len(train_dataloader)} accelerator.log(logs, step=epoch + 1) @@ -722,11 +736,12 @@ class TextualInversionTrainer: global_step, accelerator.device, vae, - tokenizer_or_list, - text_encoder_or_list, + tokenizers, + text_encoders, unet, prompt_replacement, ) + accelerator.log({}) # end of epoch diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index ddd03d53..2a2b4231 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -407,7 +407,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) # function for saving/removing @@ -461,7 +463,7 @@ def train(args): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -473,8 +475,9 @@ def train(args): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -486,7 +489,7 @@ def train(args): if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし @@ -538,7 +541,7 @@ def train(args): remove_model(remove_ckpt_name) current_loss = loss.detach().item() - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} if ( args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() @@ -556,7 +559,7 @@ def train(args): if global_step >= args.max_train_steps: break - if args.logging_dir is not None: + if len(accelerator.trackers) > 0: logs = {"loss/epoch": loss_total / len(train_dataloader)} accelerator.log(logs, step=epoch + 1)