diff --git a/flux_train.py b/flux_train.py index 4aa67220..934bd6bd 100644 --- a/flux_train.py +++ b/flux_train.py @@ -381,10 +381,27 @@ def train(args): raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") optimizer_train_fn = lambda: None # dummy function optimizer_eval_fn = lambda: None # dummy function + + if args.optimizer_type == "adafactor" and args.full_bf16: + logger.warning("Use of --blockwise_fused_optimizer with Adafactor optimizer prevents stochastic/kahan weight updates.") else: _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + # Pass any Kahan summation arg to the optimizer + if args.kahan_summation: + # Self check parameter compatibility + if args.optimizer_type != "adafactor": + logger.warning("Kahan summation has been requested, but currently this is only supported by the supplied Adafactor optimizer.") + elif not args.full_bf16: + logger.warning("Kahan summation requires --full_bf16") + elif args.blockwise_fused_optimizers: + logger.warning("Kahan summation has been requested, but it is incompatible with --blockwise_fused_optimizer. "\ + "Perhaps try --fused_backward_pass instead.") + else: + logger.info("Using Kahan summation") + optimizer.use_kahan_summation = args.kahan_summation + # prepare dataloader # strategies are set here because they cannot be referenced in another process. Copy them with the dataset # some strategies can be None @@ -816,6 +833,12 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", ) + parser.add_argument( + "--kahan_summation", + action="store_true", + help="Offloads to CPU the float parts lost during bf16 quantization, and re-adds them to the next step / "\ + "bf16 量子化中に失われた浮動小数点部分を CPU にオフロードし、次のステップに再度追加します", + ) parser.add_argument( "--skip_latents_validity_check", action="store_true", diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py index b5afa236..80ae4162 100644 --- a/library/adafactor_fused.py +++ b/library/adafactor_fused.py @@ -28,6 +28,60 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor): del result +# Kahan summation for bfloat16 +# The implementation was provided by araleza. +# Based on paper "Revisiting BFloat16 Training": https://arxiv.org/pdf/2010.06192 + +def copy_kahan_(target: torch.Tensor, source: torch.Tensor, state, update): + """ + Copies source into target using Kahan summation. + + The lower bits of the float32 weight that are lost on conversion to bfloat16 + are sent to the CPU until the next step, where they are re-added onto the weights + before adding the gradient update. This produces near float32-like weight behavior, + although the copies back and forth to main memory result in slower training steps. + + Args: + target: the target tensor with dtype=bfloat16 + source: the target tensor with dtype=float32 + state: the optimizer state, used to store kahan residuals + update: the change in weights due to the gradient + """ + + # Initialize residuals to 0 for first step + if state.get('kahan_residuals') is None: + state['kahan_residuals'] = torch.zeros_like(source, dtype=torch.int16) + + # Need this in 32 bit as PyTorch doesn't support mixed 32-bit and 16-bit math operations + state['kahan_residuals'] = state['kahan_residuals'].to(source.device).to(dtype=torch.int32) + + # Bring the previous step's lower bits of the weights back from the + # cpu device, and add them back to the weights of the current step. + source_i32 = source.view(dtype=torch.int32) # Can't do math on uint32 + source_i32.add_(state['kahan_residuals']) + + # If the Kahan residual was >=0.5 then the cast to bf16 rounded up + rounded_up = state['kahan_residuals'] >= 32768 + source_i32[rounded_up] -= 65536 + + # Must add the gradient update after the bottom bits are restored in case + # the exponent is changed by the update, or the -65536 on the line above + # would drop the uint32 value below zero, which is invalid. + source.add_(-update) + + # Get the lower bits into the residual + torch.bitwise_and(source_i32, 0x0000FFFF, out=state['kahan_residuals']) + + source_i32.add_(32768) # Add offset so clipping bits performs round-to-nearest + source_i32.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 # Leave only upper bits in source + + # Move the 16-bit Kahan bits from VRAM to main memory + state['kahan_residuals'] = state['kahan_residuals'].to(dtype=torch.uint16).to("cpu") + + # Copy the quantized floats into the target tensor + target.copy_(source) + + @torch.no_grad() def adafactor_step_param(self, p, group): if p.grad is None: @@ -102,13 +156,19 @@ def adafactor_step_param(self, p, group): if group["weight_decay"] != 0: p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) - p_data_fp32.add_(-update) + # Add on gradient update, but not if using kahan summation as the bottom + # bits must be restored first. (This update occurs in copy_kahan_() instead) + if not self.optimizer.use_kahan_summation: + p_data_fp32.add_(-update) # if p.dtype in {torch.float16, torch.bfloat16}: # p.copy_(p_data_fp32) if p.dtype == torch.bfloat16: - copy_stochastic_(p, p_data_fp32) + if self.optimizer.use_kahan_summation: + copy_kahan_(p, p_data_fp32, state, update) + else: + copy_stochastic_(p, p_data_fp32) elif p.dtype == torch.float16: p.copy_(p_data_fp32)