Added support for Kahan summation for Adafactor-optimized Flux FFT

This commit is contained in:
araleza
2025-07-23 14:34:32 +01:00
parent 4987057701
commit 6517b2b838
2 changed files with 77 additions and 11 deletions

View File

@@ -381,10 +381,25 @@ def train(args):
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
optimizer_train_fn = lambda: None # dummy function
optimizer_eval_fn = lambda: None # dummy function
if args.optimizer_type == "adafactor" and args.full_bf16:
logger.warning("Use of --blockwise_fused_optimizers with Adafactor optimizer prevents stochastic/kahan weight updates.")
else:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
# Pass any Kahan summation arg to the optimizer
if args.kahan_summation:
# Self check parameter compatibility
if args.optimizer_type != "adafactor":
logger.warning("Kahan summation has been requested, but currently this is only supported by the supplied Adafactor optimizer.")
if not args.full_bf16:
logger.warning("Kahan summation require --full_bf16")
if args.blockwise_fused_optimizers:
logger.warning("Kahan summation has been requested, but it is incompatible with --blockwise_fused_optimizer. "\
"Perhaps try --fused_backward_pass instead.")
optimizer.use_kahan_summation = args.kahan_summation
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
@@ -815,6 +830,12 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
)
parser.add_argument(
"--kahan-summation",
action="store_true",
help="Offloads to CPU the float parts lost during bf16 quantization, and re-adds them to the next step / "\
"bf16 量子化中に失われた浮動小数点部分を CPU にオフロードし、次のステップに再度追加します",
)
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",
@@ -838,13 +859,3 @@ def setup_parser() -> argparse.ArgumentParser:
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -28,6 +28,58 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
del result
# Kahan summation for bfloat16
# The implementation was provided by araleza.
# Based on paper "Revisiting BFloat16 Training": https://arxiv.org/pdf/2010.06192
kahan_residuals = []  # per-tensor float32 residuals, kept on the CPU between steps
tensor_index = 0  # position of the current tensor within the step's fixed call order
prev_step = 0  # step number seen on the previous call; a change marks a new step


def copy_kahan_(target: torch.Tensor, source: torch.Tensor, step: int):
    """
    Copies source into target using Kahan summation.

    The part of the float32 weight that is lost on conversion to bfloat16 is sent
    to the CPU until the next step, where it is re-added onto that step's updated
    weight. This produces near float32-like weight behavior, although the copies
    back and forth to main memory result in slower training steps.

    Args:
        target: the target tensor with dtype=bfloat16
        source: the source tensor with dtype=float32
        step: the optimizer step counter; used to detect the start of a new step
              so the per-tensor residual index can be restarted
    """
    global kahan_residuals, tensor_index, prev_step

    # Calculate the group index of the current residual Tensor. Tensors
    # pass through this copy function in the same order at each step, so a
    # running counter identifies which residual belongs to this tensor.
    tensor_index += 1
    if prev_step != step:  # Starting new step?
        prev_step = step
        tensor_index = 0

    # Initialize this tensor's residual to 0.0 the first time it is seen
    if len(kahan_residuals) <= tensor_index:
        kahan_residuals.append(torch.zeros_like(source))

    # Bring the residual from the previous step back from the cpu device, and add it to the
    # float32 weights of the current step
    summed = kahan_residuals[tensor_index].detach().to(source.device)  # Residual is float32 type
    summed.add_(source)

    # Mask off the lower 16 bits of the mantissa, adding 32768 (0x8000) in order to
    # round to the nearest bfloat16 value when the lower bits are clipped off.
    # The int ops run on a clone, so `summed` itself is left unmodified.
    summed_i32 = summed.view(dtype=torch.int32).detach().clone()
    summed_quantized_i32 = summed_i32.add_(32768).bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32
    summed_quantized = summed_quantized_i32.view(dtype=torch.float32)

    # The next residual is the difference between the unquantized and quantized weights
    kahan_residuals[tensor_index] = summed.sub(summed_quantized).detach().to("cpu")

    # Copy the quantized floats into the target tensor
    target.copy_(summed_quantized)
@torch.no_grad()
def adafactor_step_param(self, p, group):
if p.grad is None:
@@ -108,7 +160,10 @@ def adafactor_step_param(self, p, group):
# p.copy_(p_data_fp32)
if p.dtype == torch.bfloat16:
copy_stochastic_(p, p_data_fp32)
if self.optimizer.use_kahan_summation:
copy_kahan_(p, p_data_fp32, state["step"])
else:
copy_stochastic_(p, p_data_fp32)
elif p.dtype == torch.float16:
p.copy_(p_data_fp32)