diff --git a/flux_train.py b/flux_train.py index 6f98adea..a631546b 100644 --- a/flux_train.py +++ b/flux_train.py @@ -381,10 +381,25 @@ def train(args): raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") optimizer_train_fn = lambda: None # dummy function optimizer_eval_fn = lambda: None # dummy function + + if args.optimizer_type == "adafactor" and args.full_bf16: + logger.warning("Use of --blockwise_fused_optimizers with Adafactor optimizer prevents stochastic/kahan weight updates.") else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + # Pass any Kahan summation arg to the optimizer + if args.kahan_summation: + # Sanity-check parameter compatibility + if args.optimizer_type != "adafactor": + logger.warning("Kahan summation has been requested, but currently this is only supported by the supplied Adafactor optimizer.") + if not args.full_bf16: + logger.warning("Kahan summation requires --full_bf16") + if args.blockwise_fused_optimizers: + logger.warning("Kahan summation has been requested, but it is incompatible with --blockwise_fused_optimizers. "\ + "Perhaps try --fused_backward_pass instead.") + optimizer.use_kahan_summation = args.kahan_summation + # prepare dataloader # strategies are set here because they cannot be referenced in another process. 
Copy them with the dataset # some strategies can be None @@ -815,6 +830,12 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", ) + parser.add_argument( + "--kahan_summation", + action="store_true", + help="Offloads to CPU the float parts lost during bf16 quantization, and re-adds them to the next step / "\ + "bf16 量子化中に失われた浮動小数点部分を CPU にオフロードし、次のステップに再度追加します", + ) parser.add_argument( "--skip_latents_validity_check", action="store_true", @@ -838,13 +859,3 @@ def setup_parser() -> argparse.ArgumentParser: help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", ) return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - train_util.verify_command_line_training_args(args) - args = train_util.read_config_from_file(args, parser) - - train(args) diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py index b5afa236..6cf1ac1a 100644 --- a/library/adafactor_fused.py +++ b/library/adafactor_fused.py @@ -28,6 +28,58 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor): del result + +# Kahan summation for bfloat16 +# The implementation was provided by araleza. +# Based on the paper "Revisiting BFloat16 Training": https://arxiv.org/pdf/2010.06192 + +kahan_residuals = [] +tensor_index = 0 +prev_step = 0 + +def copy_kahan_(target: torch.Tensor, source: torch.Tensor, step): + """ + Copies source into target using Kahan summations. + + The part of the float32 weight that is lost on conversion to bfloat16 is sent + to the CPU until the next step, where it is re-added onto that step's updated + weight. This produces near float32-like weight behavior, although the copies + back and forth to main memory result in slower training steps. 
+ + Args: + target: the target tensor with dtype=bfloat16 + source: the source tensor with dtype=float32 + """ + global kahan_residuals, tensor_index, prev_step + + # Calculate the group index of the current residual Tensor. Tensors + # pass through this copy function in the same order at each step. + tensor_index += 1 + if prev_step != step: # Starting new step? + prev_step = step + tensor_index = 0 + + # Initialize residuals to 0.0 for first step + if len(kahan_residuals) <= tensor_index: + kahan_residuals += [torch.zeros_like(source)] + + # Bring the residual from the previous step back from the cpu device, and add it to the + # float32 weights of the current step + summed = kahan_residuals[tensor_index].detach().to(source.device) # Residual is float32 type + summed.add_(source) + + # Mask off the lower 16 bits of the mantissa, adding 32768 in order to + # round-to-nearest when the lower bits are clipped off + summed_i32 = summed.view(dtype=torch.int32).detach().clone() + summed_quantized_i32 = summed_i32.add_(32768).bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 + summed_quantized = summed_quantized_i32.view(dtype=torch.float32) + + # The next residual is the difference between the quantized and unquantized weights + kahan_residuals[tensor_index] = summed.sub(summed_quantized).detach().to("cpu") + + # Copy the quantized floats into the target tensor + target.copy_(summed_quantized) + + @torch.no_grad() def adafactor_step_param(self, p, group): if p.grad is None: @@ -108,7 +160,10 @@ def adafactor_step_param(self, p, group): # p.copy_(p_data_fp32) if p.dtype == torch.bfloat16: - copy_stochastic_(p, p_data_fp32) + if getattr(self, "use_kahan_summation", False): + copy_kahan_(p, p_data_fp32, state["step"]) + else: + copy_stochastic_(p, p_data_fp32) elif p.dtype == torch.float16: p.copy_(p_data_fp32)