From 6517b2b8384e57c2599d25fadd74391acde91f42 Mon Sep 17 00:00:00 2001 From: araleza <70412719+araleza@users.noreply.github.com> Date: Wed, 23 Jul 2025 14:34:32 +0100 Subject: [PATCH 1/8] Added support for Kahan summation for Adafactor-optimized Flux FFT --- flux_train.py | 31 ++++++++++++++------- library/adafactor_fused.py | 57 +++++++++++++++++++++++++++++++++++++- 2 files changed, 77 insertions(+), 11 deletions(-) diff --git a/flux_train.py b/flux_train.py index 6f98adea..a631546b 100644 --- a/flux_train.py +++ b/flux_train.py @@ -381,10 +381,25 @@ def train(args): raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") optimizer_train_fn = lambda: None # dummy function optimizer_eval_fn = lambda: None # dummy function + + if args.optimizer_type == "adafactor" and args.full_bf16: + logger.warning("Use of --blockwise_fused_optimizers with Adafactor optimizer prevents stochastic/kahan weight updates.") else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + # Pass any Kahan summation arg to the optimizer + if args.kahan_summation: + # Self check parameter compatibility + if args.optimizer_type != "adafactor": + logger.warning("Kahan summation has been requested, but currently this is only supported by the supplied Adafactor optimizer.") + if not args.full_bf16: + logger.warning("Kahan summation require --full_bf16") + if args.blockwise_fused_optimizers: + logger.warning("Kahan summation has been requested, but it is incompatible with --blockwise_fused_optimizer. "\ + "Perhaps try --fused_backward_pass instead.") + optimizer.use_kahan_summation = args.kahan_summation + # prepare dataloader # strategies are set here because they cannot be referenced in another process. Copy them with the dataset # some strategies can be None @@ -815,6 +830,12 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", ) + parser.add_argument( + "--kahan-summation", + action="store_true", + help="Offloads to CPU the float parts lost during bf16 quantization, and re-adds them to the next step / "\ + "bf16 量子化中に失われた浮動小数点部分を CPU にオフロードし、次のステップに再度追加します", + ) parser.add_argument( "--skip_latents_validity_check", action="store_true", @@ -838,13 +859,3 @@ def setup_parser() -> argparse.ArgumentParser: help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", ) return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - train_util.verify_command_line_training_args(args) - args = train_util.read_config_from_file(args, parser) - - train(args) diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py index b5afa236..6cf1ac1a 100644 --- a/library/adafactor_fused.py +++ b/library/adafactor_fused.py @@ -28,6 +28,58 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor): del result +# Kahan summation for bfloat16 +# The implementation was provided by araleza. +# Base on paper "Revisiting BFloat16 Training": https://arxiv.org/pdf/2010.06192 + +kahan_residuals = [] +tensor_index = 0 +prev_step = 0 + +def copy_kahan_(target: torch.Tensor, source: torch.Tensor, step): + """ + Copies source into target using Kahan summations. + + The part of the float32 weight that is lost on conversion to bfloat16 is sent + to the CPU until the next step, where it is re-added onto that step's updated + weight. This produces near float32-like weight behavior, although the copies + back and forth to main memory result in slower training steps. + + Args: + target: the target tensor with dtype=bfloat16 + source: the target tensor with dtype=float32 + """ + global kahan_residuals, tensor_index, prev_step + + # Calculate the group index of the current residual Tensor. Tensors + # pass through this copy function in the same order at each step. + tensor_index += 1 + if prev_step != step: # Starting new step? + prev_step = step + tensor_index = 0 + + # Initialize residuals to 0.0 for first step + if len(kahan_residuals) <= tensor_index: + kahan_residuals += [torch.zeros_like(source)] + + # Bring the residual from the previous step back from the cpu device, and add it to the + # float32 weights of the current step + summed = kahan_residuals[tensor_index].detach().to(source.device) # Residual is float32 type + summed.add_(source) + + # Mask off the lower 16 bits of the mantissa, adding 32768 in order to + # round-to-nearest when the lower bits are clipped off + summed_i32 = summed.view(dtype=torch.int32).detach().clone() + summed_quantized_i32 = summed_i32.add_(32768).bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 + summed_quantized = summed_quantized_i32.view(dtype=torch.float32) + + # The next residual is the difference between the quantized and unquantized weights + kahan_residuals[tensor_index] = summed.sub(summed_quantized).detach().to("cpu") + + # Copy the quantized floats into the target tensor + target.copy_(summed_quantized) + + @torch.no_grad() def adafactor_step_param(self, p, group): if p.grad is None: @@ -108,7 +160,10 @@ def adafactor_step_param(self, p, group): # p.copy_(p_data_fp32) if p.dtype == torch.bfloat16: - copy_stochastic_(p, p_data_fp32) + if self.optimizer.use_kahan_summation: + copy_kahan_(p, p_data_fp32, state["step"]) + else: + copy_stochastic_(p, p_data_fp32) elif p.dtype == torch.float16: p.copy_(p_data_fp32) From da6416a2fc9b9f0e09754babb48ab44116f1212b Mon Sep 17 00:00:00 2001 From: araleza <70412719+araleza@users.noreply.github.com> Date: Wed, 23 Jul 2025 15:08:24 +0100 Subject: [PATCH 2/8] Restoring the deleted __main__ function and fixing a warning typo --- flux_train.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/flux_train.py b/flux_train.py index a631546b..d60fbb0e 100644 --- a/flux_train.py +++ b/flux_train.py @@ -394,7 +394,7 @@ def train(args): if args.optimizer_type != "adafactor": logger.warning("Kahan summation has been requested, but currently this is only supported by the supplied Adafactor optimizer.") if not args.full_bf16: - logger.warning("Kahan summation require --full_bf16") + logger.warning("Kahan summation requires --full_bf16") if args.blockwise_fused_optimizers: logger.warning("Kahan summation has been requested, but it is incompatible with --blockwise_fused_optimizer. "\ "Perhaps try --fused_backward_pass instead.") @@ -859,3 +859,13 @@ def setup_parser() -> argparse.ArgumentParser: help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", ) return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) From bb7750fbcafaccafe99883cb549554079009f9c8 Mon Sep 17 00:00:00 2001 From: araleza <70412719+araleza@users.noreply.github.com> Date: Wed, 23 Jul 2025 15:10:57 +0100 Subject: [PATCH 3/8] Fixed typo in comment --- library/adafactor_fused.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py index 6cf1ac1a..47a4e374 100644 --- a/library/adafactor_fused.py +++ b/library/adafactor_fused.py @@ -30,7 +30,7 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor): # Kahan summation for bfloat16 # The implementation was provided by araleza. -# Base on paper "Revisiting BFloat16 Training": https://arxiv.org/pdf/2010.06192 +# Based on paper "Revisiting BFloat16 Training": https://arxiv.org/pdf/2010.06192 kahan_residuals = [] tensor_index = 0 From acb4cf32e882d3b241ecb8ed84ee4f17a8fa5b78 Mon Sep 17 00:00:00 2001 From: araleza <70412719+araleza@users.noreply.github.com> Date: Wed, 23 Jul 2025 18:25:07 +0100 Subject: [PATCH 4/8] Fixed a warning typo, and changed --kahan-summation to --kahan_summation --- flux_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flux_train.py b/flux_train.py index d60fbb0e..fad58d2b 100644 --- a/flux_train.py +++ b/flux_train.py @@ -383,7 +383,7 @@ def train(args): optimizer_eval_fn = lambda: None # dummy function if args.optimizer_type == "adafactor" and args.full_bf16: - logger.warning("Use of --blockwise_fused_optimizers with Adafactor optimizer prevents stochastic/kahan weight updates.") + logger.warning("Use of --blockwise_fused_optimizer with Adafactor optimizer prevents stochastic/kahan weight updates.") else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) @@ -831,7 +831,7 @@ def setup_parser() -> argparse.ArgumentParser: help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", ) parser.add_argument( - "--kahan-summation", + "--kahan_summation", action="store_true", help="Offloads to CPU the float parts lost during bf16 quantization, and re-adds them to the next step / "\ "bf16 量子化中に失われた浮動小数点部分を CPU にオフロードし、次のステップに再度追加します", From 3f0230a2863c4f708e1c83f8832920a85fa66137 Mon Sep 17 00:00:00 2001 From: araleza <70412719+araleza@users.noreply.github.com> Date: Tue, 29 Jul 2025 10:05:06 +0100 Subject: [PATCH 5/8] Now sending int16s instead of f32s to cpu device; faster and maybe more accurate --- library/adafactor_fused.py | 59 ++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 21 deletions(-) diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py index 47a4e374..d1d4e79d 100644 --- a/library/adafactor_fused.py +++ b/library/adafactor_fused.py @@ -36,18 +36,20 @@ kahan_residuals = [] tensor_index = 0 prev_step = 0 -def copy_kahan_(target: torch.Tensor, source: torch.Tensor, step): +def copy_kahan_(target: torch.Tensor, source: torch.Tensor, step, update): """ - Copies source into target using Kahan summations. + Copies source into target using Kahan summation. - The part of the float32 weight that is lost on conversion to bfloat16 is sent - to the CPU until the next step, where it is re-added onto that step's updated - weight. This produces near float32-like weight behavior, although the copies - back and forth to main memory result in slower training steps. + The lower bits of the float32 weight that are lost on conversion to bfloat16 + are sent to the CPU until the next step, where they are re-added onto the weights + before adding the gradient update. This produces near float32-like weight behavior, + although the copies back and forth to main memory result in slower training steps. Args: target: the target tensor with dtype=bfloat16 source: the target tensor with dtype=float32 + step: the global training step count + update: the change in weights due to the gradient """ global kahan_residuals, tensor_index, prev_step @@ -58,26 +60,38 @@ def copy_kahan_(target: torch.Tensor, source: torch.Tensor, step): prev_step = step tensor_index = 0 - # Initialize residuals to 0.0 for first step + # Initialize residuals to 0 for first step if len(kahan_residuals) <= tensor_index: - kahan_residuals += [torch.zeros_like(source)] + kahan_residuals += [torch.zeros_like(source, dtype=torch.int16)] + + # Need this in 32 bit as PyTorch doesn't support mixed 32-bit and 16-bit math operations + kahan_residuals[tensor_index] = kahan_residuals[tensor_index].detach().to(source.device).to(dtype=torch.int32) + + # Bring the previous step's lower bits of the weights back from the + # cpu device, and add them back to the weights of the current step. + source_i32 = source.view(dtype=torch.int32) # Can't do math on uint32 + source_i32.add_(kahan_residuals[tensor_index]) - # Bring the residual from the previous step back from the cpu device, and add it to the - # float32 weights of the current step - summed = kahan_residuals[tensor_index].detach().to(source.device) # Residual is float32 type - summed.add_(source) + # If the Kahan residual was >=0.5 then the cast to bf16 rounded up + rounded_up = kahan_residuals[tensor_index] >= 32768 + source_i32[rounded_up] -= 65536 - # Mask off the lower 16 bits of the mantissa, adding 32768 in order to - # round-to-nearest when the lower bits are clipped off - summed_i32 = summed.view(dtype=torch.int32).detach().clone() - summed_quantized_i32 = summed_i32.add_(32768).bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 - summed_quantized = summed_quantized_i32.view(dtype=torch.float32) + # Must add the gradient update after the bottom bits are restored in case + # the exponent is changed by the update, or the -65536 on the line above + # would drop the uint32 value below zero, which is invalid. + source.add_(-update) - # The next residual is the difference between the quantized and unquantized weights - kahan_residuals[tensor_index] = summed.sub(summed_quantized).detach().to("cpu") + # Get the lower bits into the residual + torch.bitwise_and(source_i32, 0x0000FFFF, out=kahan_residuals[tensor_index]) + + source_i32.add_(32768) # Add offset so clipping bits performs round-to-nearest + source_i32.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 # Leave only upper bits in source + + # Move the 16-bit Kahan bits from VRAM to main memory + kahan_residuals[tensor_index] = kahan_residuals[tensor_index].detach().to(dtype=torch.uint16).to("cpu") # Copy the quantized floats into the target tensor - target.copy_(summed_quantized) + target.copy_(source) @torch.no_grad() @@ -154,7 +168,10 @@ def adafactor_step_param(self, p, group): if group["weight_decay"] != 0: p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) - p_data_fp32.add_(-update) + # Add on gradient update, but not if using kahan summation as the bottom + # bits must be restored first. (This update occurs in copy_kahan_() instead) + if not self.optimizer.use_kahan_summation: + p_data_fp32.add_(-update) # if p.dtype in {torch.float16, torch.bfloat16}: # p.copy_(p_data_fp32) From 648994271ea6976fa69d0659e01788e567c58b98 Mon Sep 17 00:00:00 2001 From: araleza <70412719+araleza@users.noreply.github.com> Date: Tue, 29 Jul 2025 10:28:26 +0100 Subject: [PATCH 6/8] Added log output message to show that Kahan summation is being used --- flux_train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/flux_train.py b/flux_train.py index fad58d2b..563e845e 100644 --- a/flux_train.py +++ b/flux_train.py @@ -393,11 +393,13 @@ def train(args): # Self check parameter compatibility if args.optimizer_type != "adafactor": logger.warning("Kahan summation has been requested, but currently this is only supported by the supplied Adafactor optimizer.") - if not args.full_bf16: + elif not args.full_bf16: logger.warning("Kahan summation requires --full_bf16") - if args.blockwise_fused_optimizers: + elif args.blockwise_fused_optimizers: logger.warning("Kahan summation has been requested, but it is incompatible with --blockwise_fused_optimizer. "\ "Perhaps try --fused_backward_pass instead.") + else: + logger.info("Using Kahan summation") optimizer.use_kahan_summation = args.kahan_summation # prepare dataloader From cd239f0fa93939bfd27a941ef4d9080232564a6b Mon Sep 17 00:00:00 2001 From: araleza <70412719+araleza@users.noreply.github.com> Date: Wed, 20 Aug 2025 16:42:15 +0100 Subject: [PATCH 7/8] Moved kahan state from file globals to optimizer state variables --- library/adafactor_fused.py | 34 ++++++++++++---------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py index d1d4e79d..c59f668c 100644 --- a/library/adafactor_fused.py +++ b/library/adafactor_fused.py @@ -32,11 +32,7 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor): # The implementation was provided by araleza. # Based on paper "Revisiting BFloat16 Training": https://arxiv.org/pdf/2010.06192 -kahan_residuals = [] -tensor_index = 0 -prev_step = 0 - -def copy_kahan_(target: torch.Tensor, source: torch.Tensor, step, update): +def copy_kahan_(target: torch.Tensor, source: torch.Tensor, state, update): """ Copies source into target using Kahan summation. @@ -48,32 +44,26 @@ def copy_kahan_(target: torch.Tensor, source: torch.Tensor, step, update): Args: target: the target tensor with dtype=bfloat16 source: the target tensor with dtype=float32 - step: the global training step count + state: the optimizer state, used to store kahan residuals update: the change in weights due to the gradient """ - global kahan_residuals, tensor_index, prev_step - - # Calculate the group index of the current residual Tensor. Tensors - # pass through this copy function in the same order at each step. - tensor_index += 1 - if prev_step != step: # Starting new step? - prev_step = step - tensor_index = 0 # Initialize residuals to 0 for first step - if len(kahan_residuals) <= tensor_index: - kahan_residuals += [torch.zeros_like(source, dtype=torch.int16)] + if state.get('kahan_residuals') is None: + state['kahan_residuals'] = torch.zeros_like(source, dtype=torch.int16) + else: + pass # Need this in 32 bit as PyTorch doesn't support mixed 32-bit and 16-bit math operations - kahan_residuals[tensor_index] = kahan_residuals[tensor_index].detach().to(source.device).to(dtype=torch.int32) + state['kahan_residuals'] = state['kahan_residuals'].to(source.device).to(dtype=torch.int32) # Bring the previous step's lower bits of the weights back from the # cpu device, and add them back to the weights of the current step. source_i32 = source.view(dtype=torch.int32) # Can't do math on uint32 - source_i32.add_(kahan_residuals[tensor_index]) + source_i32.add_(state['kahan_residuals']) # If the Kahan residual was >=0.5 then the cast to bf16 rounded up - rounded_up = kahan_residuals[tensor_index] >= 32768 + rounded_up = state['kahan_residuals'] >= 32768 source_i32[rounded_up] -= 65536 # Must add the gradient update after the bottom bits are restored in case @@ -82,13 +72,13 @@ def copy_kahan_(target: torch.Tensor, source: torch.Tensor, step, update): source.add_(-update) # Get the lower bits into the residual - torch.bitwise_and(source_i32, 0x0000FFFF, out=kahan_residuals[tensor_index]) + torch.bitwise_and(source_i32, 0x0000FFFF, out=state['kahan_residuals']) source_i32.add_(32768) # Add offset so clipping bits performs round-to-nearest source_i32.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 # Leave only upper bits in source # Move the 16-bit Kahan bits from VRAM to main memory - kahan_residuals[tensor_index] = kahan_residuals[tensor_index].detach().to(dtype=torch.uint16).to("cpu") + state['kahan_residuals'] = state['kahan_residuals'].to(dtype=torch.uint16).to("cpu") # Copy the quantized floats into the target tensor target.copy_(source) @@ -178,7 +168,7 @@ def adafactor_step_param(self, p, group): if p.dtype == torch.bfloat16: if self.optimizer.use_kahan_summation: - copy_kahan_(p, p_data_fp32, state["step"]) + copy_kahan_(p, p_data_fp32, state, update) else: copy_stochastic_(p, p_data_fp32) elif p.dtype == torch.float16: From ac8ae581dbc53cb66760886351bfb2af1aeb4ea5 Mon Sep 17 00:00:00 2001 From: araleza <70412719+araleza@users.noreply.github.com> Date: Wed, 20 Aug 2025 16:49:37 +0100 Subject: [PATCH 8/8] Removed some no-effect lines used for a debug breakpoint --- library/adafactor_fused.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py index c59f668c..80ae4162 100644 --- a/library/adafactor_fused.py +++ b/library/adafactor_fused.py @@ -51,8 +51,6 @@ def copy_kahan_(target: torch.Tensor, source: torch.Tensor, state, update): # Initialize residuals to 0 for first step if state.get('kahan_residuals') is None: state['kahan_residuals'] = torch.zeros_like(source, dtype=torch.int16) - else: - pass # Need this in 32 bit as PyTorch doesn't support mixed 32-bit and 16-bit math operations state['kahan_residuals'] = state['kahan_residuals'].to(source.device).to(dtype=torch.int32)