mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 14:45:19 +00:00
add stochastic rounding, fix single block
This commit is contained in:
19
README.md
19
README.md
@@ -9,6 +9,15 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
|
|||||||
The command to install PyTorch is as follows:
|
The command to install PyTorch is as follows:
|
||||||
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||||
|
|
||||||
|
Aug 21, 2024 (update 3):
|
||||||
|
- There is a bug where the `--full_bf16` option is enabled even if it is not specified in `flux_train.py`. The bug will be fixed soon. __Please specify the `--full_bf16` option explicitly, especially when training with 24GB VRAM.__
|
||||||
|
- Stochastic rounding is now implemented when `--fused_backward_pass` is specified. The implementation is
|
||||||
|
based on the code provided by 2kpr. Thank you so much!
|
||||||
|
- With this change, `--fused_backward_pass` is recommended over `--blockwise_fused_optimizers` when `--full_bf16` is specified.
|
||||||
|
- Please note that `--fused_backward_pass` is only supported with Adafactor.
|
||||||
|
- The sample command in [FLUX.1 fine-tuning](#flux1-fine-tuning) is updated to reflect these changes.
|
||||||
|
- Fixed an issue where `--single_blocks_to_swap` was not working in `flux_train.py`.
|
||||||
|
|
||||||
Aug 21, 2024 (update 2):
|
Aug 21, 2024 (update 2):
|
||||||
Fixed an error in applying mask in Attention. The attention mask was float, but it should be bool.
|
Fixed an error in applying mask in Attention. The attention mask was float, but it should be bool.
|
||||||
|
|
||||||
@@ -142,7 +151,7 @@ accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_t
|
|||||||
--learning_rate 5e-5 --max_train_epochs 4 --sdpa --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1
|
--learning_rate 5e-5 --max_train_epochs 4 --sdpa --highvram --cache_text_encoder_outputs_to_disk --cache_latents_to_disk --save_every_n_epochs 1
|
||||||
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"
|
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"
|
||||||
--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0
|
--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0
|
||||||
--blockwise_fused_optimizers --double_blocks_to_swap 6 --cpu_offload_checkpointing
|
--fused_backward_pass --double_blocks_to_swap 6 --cpu_offload_checkpointing --full_bf16
|
||||||
```
|
```
|
||||||
|
|
||||||
(Combine the command into one line.)
|
(Combine the command into one line.)
|
||||||
@@ -151,9 +160,13 @@ Sample image generation during training is not tested yet.
|
|||||||
|
|
||||||
Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available.
|
Options are almost the same as LoRA training. The difference is `--blockwise_fused_optimizers`, `--double_blocks_to_swap` and `--cpu_offload_checkpointing`. `--single_blocks_to_swap` is also available.
|
||||||
|
|
||||||
`--blockwise_fused_optimizers` enables the fusing of the optimizer for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--fused_optimizer_groups` is deprecated due to the addition of this option for FLUX.1 training.
|
`--full_bf16` enables the training with bf16 (weights and gradients).
|
||||||
|
|
||||||
`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--blockwise_fused_optimizers`.
|
`--fused_backward_pass` enables the fusing of the optimizer step into the backward pass for each parameter. This reduces the memory usage during training. Only Adafactor optimizer is supported for now. Stochastic rounding is also enabled when `--fused_backward_pass` and `--full_bf16` are specified.
|
||||||
|
|
||||||
|
`--blockwise_fused_optimizers` enables the fusing of the optimizer step into the backward pass for each block. This is similar to `--fused_backward_pass`. Any optimizer can be used, but Adafactor is recommended for memory efficiency. `--blockwise_fused_optimizers` cannot be used with `--fused_backward_pass`. Stochastic rounding is not supported for now.
|
||||||
|
|
||||||
|
`--double_blocks_to_swap` and `--single_blocks_to_swap` are the number of double blocks and single blocks to swap. The default is None (no swap). These options must be combined with `--fused_backward_pass` or `--blockwise_fused_optimizers`. `--double_blocks_to_swap` can be specified with `--single_blocks_to_swap`. The recommended maximum number of blocks to swap is 9 for double blocks and 18 for single blocks.
|
||||||
|
|
||||||
`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage.
|
`--cpu_offload_checkpointing` is to offload the gradient checkpointing to CPU. This reduces about 2GB of VRAM usage.
|
||||||
|
|
||||||
|
|||||||
@@ -277,7 +277,10 @@ def train(args):
|
|||||||
training_models = []
|
training_models = []
|
||||||
params_to_optimize = []
|
params_to_optimize = []
|
||||||
training_models.append(flux)
|
training_models.append(flux)
|
||||||
params_to_optimize.append({"params": list(flux.parameters()), "lr": args.learning_rate})
|
name_and_params = list(flux.named_parameters())
|
||||||
|
# single param group for now
|
||||||
|
params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate})
|
||||||
|
param_names = [[n for n, _ in name_and_params]]
|
||||||
|
|
||||||
# calculate number of trainable parameters
|
# calculate number of trainable parameters
|
||||||
n_params = 0
|
n_params = 0
|
||||||
@@ -433,9 +436,79 @@ def train(args):
|
|||||||
import library.adafactor_fused
|
import library.adafactor_fused
|
||||||
|
|
||||||
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
||||||
for param_group in optimizer.param_groups:
|
|
||||||
for parameter in param_group["params"]:
|
double_blocks_to_swap = args.double_blocks_to_swap
|
||||||
|
single_blocks_to_swap = args.single_blocks_to_swap
|
||||||
|
num_double_blocks = len(flux.double_blocks)
|
||||||
|
num_single_blocks = len(flux.single_blocks)
|
||||||
|
handled_double_block_indices = set()
|
||||||
|
handled_single_block_indices = set()
|
||||||
|
|
||||||
|
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
|
||||||
|
for parameter, param_name in zip(param_group["params"], param_name_group):
|
||||||
if parameter.requires_grad:
|
if parameter.requires_grad:
|
||||||
|
grad_hook = None
|
||||||
|
|
||||||
|
if double_blocks_to_swap:
|
||||||
|
if param_name.startswith("double_blocks"):
|
||||||
|
block_idx = int(param_name.split(".")[1])
|
||||||
|
if (
|
||||||
|
block_idx not in handled_double_block_indices
|
||||||
|
and block_idx >= (num_double_blocks - double_blocks_to_swap) - 1
|
||||||
|
and block_idx < num_double_blocks - 1
|
||||||
|
):
|
||||||
|
# swap next (already backpropagated) block
|
||||||
|
handled_double_block_indices.add(block_idx)
|
||||||
|
block_idx_cpu = block_idx + 1
|
||||||
|
block_idx_cuda = double_blocks_to_swap - (num_double_blocks - block_idx_cpu)
|
||||||
|
|
||||||
|
# create swap hook
|
||||||
|
def create_double_swap_grad_hook(bidx, bidx_cuda):
|
||||||
|
def __grad_hook(tensor: torch.Tensor):
|
||||||
|
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||||
|
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
||||||
|
optimizer.step_param(tensor, param_group)
|
||||||
|
tensor.grad = None
|
||||||
|
|
||||||
|
# swap blocks if necessary
|
||||||
|
flux.double_blocks[bidx].to("cpu")
|
||||||
|
flux.double_blocks[bidx_cuda].to(accelerator.device)
|
||||||
|
# print(f"Move double block {bidx} to cpu and {bidx_cuda} to device")
|
||||||
|
|
||||||
|
return __grad_hook
|
||||||
|
|
||||||
|
grad_hook = create_double_swap_grad_hook(block_idx_cpu, block_idx_cuda)
|
||||||
|
if single_blocks_to_swap:
|
||||||
|
if param_name.startswith("single_blocks"):
|
||||||
|
block_idx = int(param_name.split(".")[1])
|
||||||
|
if (
|
||||||
|
block_idx not in handled_single_block_indices
|
||||||
|
and block_idx >= (num_single_blocks - single_blocks_to_swap) - 1
|
||||||
|
and block_idx < num_single_blocks - 1
|
||||||
|
):
|
||||||
|
handled_single_block_indices.add(block_idx)
|
||||||
|
block_idx_cpu = block_idx + 1
|
||||||
|
block_idx_cuda = single_blocks_to_swap - (num_single_blocks - block_idx_cpu)
|
||||||
|
# print(param_name, block_idx_cpu, block_idx_cuda)
|
||||||
|
|
||||||
|
# create swap hook
|
||||||
|
def create_single_swap_grad_hook(bidx, bidx_cuda):
|
||||||
|
def __grad_hook(tensor: torch.Tensor):
|
||||||
|
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||||
|
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
||||||
|
optimizer.step_param(tensor, param_group)
|
||||||
|
tensor.grad = None
|
||||||
|
|
||||||
|
# swap blocks if necessary
|
||||||
|
flux.single_blocks[bidx].to("cpu")
|
||||||
|
flux.single_blocks[bidx_cuda].to(accelerator.device)
|
||||||
|
# print(f"Move single block {bidx} to cpu and {bidx_cuda} to device")
|
||||||
|
|
||||||
|
return __grad_hook
|
||||||
|
|
||||||
|
grad_hook = create_single_swap_grad_hook(block_idx_cpu, block_idx_cuda)
|
||||||
|
|
||||||
|
if grad_hook is None:
|
||||||
|
|
||||||
def __grad_hook(tensor: torch.Tensor, param_group=param_group):
|
def __grad_hook(tensor: torch.Tensor, param_group=param_group):
|
||||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||||
@@ -443,7 +516,9 @@ def train(args):
|
|||||||
optimizer.step_param(tensor, param_group)
|
optimizer.step_param(tensor, param_group)
|
||||||
tensor.grad = None
|
tensor.grad = None
|
||||||
|
|
||||||
parameter.register_post_accumulate_grad_hook(__grad_hook)
|
grad_hook = __grad_hook
|
||||||
|
|
||||||
|
parameter.register_post_accumulate_grad_hook(grad_hook)
|
||||||
|
|
||||||
elif args.blockwise_fused_optimizers:
|
elif args.blockwise_fused_optimizers:
|
||||||
# prepare for additional optimizers and lr schedulers
|
# prepare for additional optimizers and lr schedulers
|
||||||
|
|||||||
@@ -2,6 +2,32 @@ import math
|
|||||||
import torch
|
import torch
|
||||||
from transformers import Adafactor
|
from transformers import Adafactor
|
||||||
|
|
||||||
|
# stochastic rounding for bfloat16
|
||||||
|
# The implementation was provided by 2kpr. Thank you very much!
|
||||||
|
|
||||||
|
def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
    """
    Copies source into target using stochastic rounding.

    Instead of always rounding to the nearest bfloat16 value, each element is
    rounded up or down with probability proportional to its distance from the
    two neighboring bfloat16 values. This avoids the systematic bias of
    round-to-nearest when parameters receive many small updates (full-bf16
    training), at the cost of per-call randomness.

    Args:
        target: the target tensor with dtype=bfloat16
        source: the source tensor with dtype=float32
    """
    # create a random 16 bit integer per element
    result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))

    # add the random number to the lower 16 bits of the fp32 mantissa;
    # a carry into the upper bits is exactly what rounds the value up
    result.add_(source.view(dtype=torch.int32))

    # mask off the lower 16 bits of the mantissa
    result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32

    # copy the higher 16 bits into the target tensor; the fp32 -> bf16 cast
    # performed by copy_ is exact here because the low mantissa bits are zero
    target.copy_(result.view(dtype=torch.float32))

    del result
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def adafactor_step_param(self, p, group):
|
def adafactor_step_param(self, p, group):
|
||||||
if p.grad is None:
|
if p.grad is None:
|
||||||
@@ -78,7 +104,12 @@ def adafactor_step_param(self, p, group):
|
|||||||
|
|
||||||
p_data_fp32.add_(-update)
|
p_data_fp32.add_(-update)
|
||||||
|
|
||||||
if p.dtype in {torch.float16, torch.bfloat16}:
|
# if p.dtype in {torch.float16, torch.bfloat16}:
|
||||||
|
# p.copy_(p_data_fp32)
|
||||||
|
|
||||||
|
if p.dtype == torch.bfloat16:
|
||||||
|
copy_stochastic_(p, p_data_fp32)
|
||||||
|
elif p.dtype == torch.float16:
|
||||||
p.copy_(p_data_fp32)
|
p.copy_(p_data_fp32)
|
||||||
|
|
||||||
|
|
||||||
@@ -101,6 +132,7 @@ def adafactor_step(self, closure=None):
|
|||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def patch_adafactor_fused(optimizer: Adafactor):
    """Monkey-patch a transformers Adafactor instance with the fused step functions.

    Binds ``adafactor_step`` as the optimizer's ``step`` and exposes
    ``adafactor_step_param`` as ``step_param`` so a single parameter can be
    updated from a post-accumulate gradient hook.
    """
    optimizer.step = adafactor_step.__get__(optimizer)
    optimizer.step_param = adafactor_step_param.__get__(optimizer)
|
||||||
|
|||||||
@@ -1078,6 +1078,7 @@ class Flux(nn.Module):
|
|||||||
if moving:
|
if moving:
|
||||||
self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True)
|
self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True)
|
||||||
# print(f"Moved single block {to_cpu_block_index} to cpu.")
|
# print(f"Moved single block {to_cpu_block_index} to cpu.")
|
||||||
|
to_cpu_block_index += 1
|
||||||
|
|
||||||
img = img[:, txt.shape[1] :, ...]
|
img = img[:, txt.shape[1] :, ...]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user