From cd239f0fa93939bfd27a941ef4d9080232564a6b Mon Sep 17 00:00:00 2001 From: araleza <70412719+araleza@users.noreply.github.com> Date: Wed, 20 Aug 2025 16:42:15 +0100 Subject: [PATCH] Moved kahan state from file globals to optimizer state variables --- library/adafactor_fused.py | 34 ++++++++++++---------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py index d1d4e79d..c59f668c 100644 --- a/library/adafactor_fused.py +++ b/library/adafactor_fused.py @@ -32,11 +32,7 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor): # The implementation was provided by araleza. # Based on paper "Revisiting BFloat16 Training": https://arxiv.org/pdf/2010.06192 -kahan_residuals = [] -tensor_index = 0 -prev_step = 0 - -def copy_kahan_(target: torch.Tensor, source: torch.Tensor, step, update): +def copy_kahan_(target: torch.Tensor, source: torch.Tensor, state, update): """ Copies source into target using Kahan summation. @@ -48,32 +44,26 @@ def copy_kahan_(target: torch.Tensor, source: torch.Tensor, step, update): Args: target: the target tensor with dtype=bfloat16 source: the target tensor with dtype=float32 - step: the global training step count + state: the optimizer state, used to store kahan residuals update: the change in weights due to the gradient """ - global kahan_residuals, tensor_index, prev_step - - # Calculate the group index of the current residual Tensor. Tensors - # pass through this copy function in the same order at each step. - tensor_index += 1 - if prev_step != step: # Starting new step? - prev_step = step - tensor_index = 0 # Initialize residuals to 0 for first step - if len(kahan_residuals) <= tensor_index: - kahan_residuals += [torch.zeros_like(source, dtype=torch.int16)] + if state.get('kahan_residuals') is None: + state['kahan_residuals'] = torch.zeros_like(source, dtype=torch.int16) + else: + pass # Need this in 32 bit as PyTorch doesn't support mixed 32-bit and 16-bit math operations - kahan_residuals[tensor_index] = kahan_residuals[tensor_index].detach().to(source.device).to(dtype=torch.int32) + state['kahan_residuals'] = state['kahan_residuals'].to(source.device).to(dtype=torch.int32) # Bring the previous step's lower bits of the weights back from the # cpu device, and add them back to the weights of the current step. source_i32 = source.view(dtype=torch.int32) # Can't do math on uint32 - source_i32.add_(kahan_residuals[tensor_index]) + source_i32.add_(state['kahan_residuals']) # If the Kahan residual was >=0.5 then the cast to bf16 rounded up - rounded_up = kahan_residuals[tensor_index] >= 32768 + rounded_up = state['kahan_residuals'] >= 32768 source_i32[rounded_up] -= 65536 # Must add the gradient update after the bottom bits are restored in case @@ -82,13 +72,13 @@ def copy_kahan_(target: torch.Tensor, source: torch.Tensor, step, update): source.add_(-update) # Get the lower bits into the residual - torch.bitwise_and(source_i32, 0x0000FFFF, out=kahan_residuals[tensor_index]) + torch.bitwise_and(source_i32, 0x0000FFFF, out=state['kahan_residuals']) source_i32.add_(32768) # Add offset so clipping bits performs round-to-nearest source_i32.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 # Leave only upper bits in source # Move the 16-bit Kahan bits from VRAM to main memory - kahan_residuals[tensor_index] = kahan_residuals[tensor_index].detach().to(dtype=torch.uint16).to("cpu") + state['kahan_residuals'] = state['kahan_residuals'].to(dtype=torch.uint16).to("cpu") # Copy the quantized floats into the target tensor target.copy_(source) @@ -178,7 +168,7 @@ def adafactor_step_param(self, p, group): if p.dtype == torch.bfloat16: if self.optimizer.use_kahan_summation: - copy_kahan_(p, p_data_fp32, state["step"]) + copy_kahan_(p, p_data_fp32, state, update) else: copy_stochastic_(p, p_data_fp32) elif p.dtype == torch.float16: