Compare commits

...

10 Commits

Author SHA1 Message Date
araleza
cb2f5975e6 Merge ac8ae581db into fa53f71ec0 2026-04-03 02:09:52 +08:00
Kohya S.
fa53f71ec0 fix: improve numerical stability by conditionally using float32 in Anima (#2302)
* fix: improve numerical stability by conditionally using float32 in block computations

* doc: update README for improvement stability for fp16 training on Anima in version 0.10.3
2026-04-02 12:36:29 +09:00
araleza
ac8ae581db Removed some no-effect lines used for a debug breakpoint 2025-08-20 16:49:37 +01:00
araleza
cd239f0fa9 Moved kahan state from file globals to optimizer state variables 2025-08-20 16:42:15 +01:00
araleza
648994271e Added log output message to show that Kahan summation is being used 2025-07-29 10:28:26 +01:00
araleza
3f0230a286 Now sending int16s instead of f32s to cpu device; faster and maybe more accurate 2025-07-29 10:05:06 +01:00
araleza
acb4cf32e8 Fixed a warning typo, and changed --kahan-summation to --kahan_summation 2025-07-23 18:25:07 +01:00
araleza
bb7750fbca Fixed typo in comment 2025-07-23 15:10:57 +01:00
araleza
da6416a2fc Restoring the deleted __main__ function and fixing a warning typo 2025-07-23 15:08:24 +01:00
araleza
6517b2b838 Added support for Kahan summation for Adafactor-optimized Flux FFT 2025-07-23 14:34:32 +01:00
5 changed files with 103 additions and 6 deletions

View File

@@ -50,6 +50,9 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
### 更新履歴
- **Version 0.10.3 (2026-04-02):**
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
- **Version 0.10.2 (2026-03-30):**
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。

View File

@@ -47,6 +47,9 @@ If you find this project helpful, please consider supporting its development via
### Change History
- **Version 0.10.3 (2026-04-02):**
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
- **Version 0.10.2 (2026-03-30):**
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
- Please refer to the [documentation](./docs/train_leco.md) for details.

View File

@@ -381,10 +381,27 @@ def train(args):
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
optimizer_train_fn = lambda: None # dummy function
optimizer_eval_fn = lambda: None # dummy function
if args.optimizer_type == "adafactor" and args.full_bf16:
logger.warning("Use of --blockwise_fused_optimizer with Adafactor optimizer prevents stochastic/kahan weight updates.")
else:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
# Pass any Kahan summation arg to the optimizer
if args.kahan_summation:
# Self check parameter compatibility
if args.optimizer_type != "adafactor":
logger.warning("Kahan summation has been requested, but currently this is only supported by the supplied Adafactor optimizer.")
elif not args.full_bf16:
logger.warning("Kahan summation requires --full_bf16")
elif args.blockwise_fused_optimizers:
logger.warning("Kahan summation has been requested, but it is incompatible with --blockwise_fused_optimizer. "\
"Perhaps try --fused_backward_pass instead.")
else:
logger.info("Using Kahan summation")
optimizer.use_kahan_summation = args.kahan_summation
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
@@ -816,6 +833,12 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
)
parser.add_argument(
"--kahan_summation",
action="store_true",
help="Offloads to CPU the float parts lost during bf16 quantization, and re-adds them to the next step / "\
"bf16 量子化中に失われた浮動小数点部分を CPU にオフロードし、次のステップに再度追加します",
)
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",

View File

@@ -28,6 +28,60 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
del result
# Kahan summation for bfloat16
# The implementation was provided by araleza.
# Based on paper "Revisiting BFloat16 Training": https://arxiv.org/pdf/2010.06192
def copy_kahan_(target: torch.Tensor, source: torch.Tensor, state, update):
"""
Copies source into target using Kahan summation.
The lower bits of the float32 weight that are lost on conversion to bfloat16
are sent to the CPU until the next step, where they are re-added onto the weights
before adding the gradient update. This produces near float32-like weight behavior,
although the copies back and forth to main memory result in slower training steps.
Args:
target: the target tensor with dtype=bfloat16
source: the target tensor with dtype=float32
state: the optimizer state, used to store kahan residuals
update: the change in weights due to the gradient
"""
# Initialize residuals to 0 for first step
if state.get('kahan_residuals') is None:
state['kahan_residuals'] = torch.zeros_like(source, dtype=torch.int16)
# Need this in 32 bit as PyTorch doesn't support mixed 32-bit and 16-bit math operations
state['kahan_residuals'] = state['kahan_residuals'].to(source.device).to(dtype=torch.int32)
# Bring the previous step's lower bits of the weights back from the
# cpu device, and add them back to the weights of the current step.
source_i32 = source.view(dtype=torch.int32) # Can't do math on uint32
source_i32.add_(state['kahan_residuals'])
# If the Kahan residual was >=0.5 then the cast to bf16 rounded up
rounded_up = state['kahan_residuals'] >= 32768
source_i32[rounded_up] -= 65536
# Must add the gradient update after the bottom bits are restored in case
# the exponent is changed by the update, or the -65536 on the line above
# would drop the uint32 value below zero, which is invalid.
source.add_(-update)
# Get the lower bits into the residual
torch.bitwise_and(source_i32, 0x0000FFFF, out=state['kahan_residuals'])
source_i32.add_(32768) # Add offset so clipping bits performs round-to-nearest
source_i32.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 # Leave only upper bits in source
# Move the 16-bit Kahan bits from VRAM to main memory
state['kahan_residuals'] = state['kahan_residuals'].to(dtype=torch.uint16).to("cpu")
# Copy the quantized floats into the target tensor
target.copy_(source)
@torch.no_grad()
def adafactor_step_param(self, p, group):
if p.grad is None:
@@ -102,13 +156,19 @@ def adafactor_step_param(self, p, group):
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
p_data_fp32.add_(-update)
# Add on gradient update, but not if using kahan summation as the bottom
# bits must be restored first. (This update occurs in copy_kahan_() instead)
if not self.optimizer.use_kahan_summation:
p_data_fp32.add_(-update)
# if p.dtype in {torch.float16, torch.bfloat16}:
# p.copy_(p_data_fp32)
if p.dtype == torch.bfloat16:
copy_stochastic_(p, p_data_fp32)
if self.optimizer.use_kahan_summation:
copy_kahan_(p, p_data_fp32, state, update)
else:
copy_stochastic_(p, p_data_fp32)
elif p.dtype == torch.float16:
p.copy_(p_data_fp32)

View File

@@ -738,9 +738,9 @@ class FinalLayer(nn.Module):
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
use_fp32: bool = False,
):
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None
@@ -863,11 +863,11 @@ class Block(nn.Module):
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
use_fp32: bool = False,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
if use_fp32:
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
x_B_T_H_W_D = x_B_T_H_W_D.float()
@@ -959,6 +959,7 @@ class Block(nn.Module):
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
use_fp32: bool = False,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -972,6 +973,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -994,6 +996,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1007,6 +1010,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1018,6 +1022,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1338,16 +1343,19 @@ class Anima(nn.Module):
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
# Determine whether to use float32 for block computations based on input dtype (use float32 for better stability when input is float16)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
for block_idx, block in enumerate(self.blocks):
if self.blocks_to_swap:
self.offloader.wait_for_block(block_idx)
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs)
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, use_fp32, **block_kwargs)
if self.blocks_to_swap:
self.offloader.submit_move_blocks(self.blocks, block_idx)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D, use_fp32=use_fp32)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp