This commit is contained in:
araleza
2026-04-03 02:09:52 +08:00
committed by GitHub
2 changed files with 85 additions and 2 deletions

View File

@@ -381,10 +381,27 @@ def train(args):
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
optimizer_train_fn = lambda: None # dummy function
optimizer_eval_fn = lambda: None # dummy function
if args.optimizer_type == "adafactor" and args.full_bf16:
logger.warning("Use of --blockwise_fused_optimizer with Adafactor optimizer prevents stochastic/kahan weight updates.")
else:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
# Pass any Kahan summation arg to the optimizer
if args.kahan_summation:
# Self check parameter compatibility
if args.optimizer_type != "adafactor":
logger.warning("Kahan summation has been requested, but currently this is only supported by the supplied Adafactor optimizer.")
elif not args.full_bf16:
logger.warning("Kahan summation requires --full_bf16")
elif args.blockwise_fused_optimizers:
logger.warning("Kahan summation has been requested, but it is incompatible with --blockwise_fused_optimizer. "\
"Perhaps try --fused_backward_pass instead.")
else:
logger.info("Using Kahan summation")
optimizer.use_kahan_summation = args.kahan_summation
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
@@ -816,6 +833,12 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
)
parser.add_argument(
"--kahan_summation",
action="store_true",
help="Offloads to CPU the float parts lost during bf16 quantization, and re-adds them to the next step / "\
"bf16 量子化中に失われた浮動小数点部分を CPU にオフロードし、次のステップに再度追加します",
)
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",

View File

@@ -28,6 +28,60 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
del result
# Kahan summation for bfloat16
# The implementation was provided by araleza.
# Based on paper "Revisiting BFloat16 Training": https://arxiv.org/pdf/2010.06192
def copy_kahan_(target: torch.Tensor, source: torch.Tensor, state, update):
"""
Copies source into target using Kahan summation.
The lower bits of the float32 weight that are lost on conversion to bfloat16
are sent to the CPU until the next step, where they are re-added onto the weights
before adding the gradient update. This produces near float32-like weight behavior,
although the copies back and forth to main memory result in slower training steps.
Args:
target: the target tensor with dtype=bfloat16
source: the target tensor with dtype=float32
state: the optimizer state, used to store kahan residuals
update: the change in weights due to the gradient
"""
# Initialize residuals to 0 for first step
if state.get('kahan_residuals') is None:
state['kahan_residuals'] = torch.zeros_like(source, dtype=torch.int16)
# Need this in 32 bit as PyTorch doesn't support mixed 32-bit and 16-bit math operations
state['kahan_residuals'] = state['kahan_residuals'].to(source.device).to(dtype=torch.int32)
# Bring the previous step's lower bits of the weights back from the
# cpu device, and add them back to the weights of the current step.
source_i32 = source.view(dtype=torch.int32) # Can't do math on uint32
source_i32.add_(state['kahan_residuals'])
# If the Kahan residual was >=0.5 then the cast to bf16 rounded up
rounded_up = state['kahan_residuals'] >= 32768
source_i32[rounded_up] -= 65536
# Must add the gradient update after the bottom bits are restored in case
# the exponent is changed by the update, or the -65536 on the line above
# would drop the uint32 value below zero, which is invalid.
source.add_(-update)
# Get the lower bits into the residual
torch.bitwise_and(source_i32, 0x0000FFFF, out=state['kahan_residuals'])
source_i32.add_(32768) # Add offset so clipping bits performs round-to-nearest
source_i32.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 # Leave only upper bits in source
# Move the 16-bit Kahan bits from VRAM to main memory
state['kahan_residuals'] = state['kahan_residuals'].to(dtype=torch.uint16).to("cpu")
# Copy the quantized floats into the target tensor
target.copy_(source)
@torch.no_grad()
def adafactor_step_param(self, p, group):
if p.grad is None:
@@ -102,13 +156,19 @@ def adafactor_step_param(self, p, group):
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
p_data_fp32.add_(-update)
# Add on gradient update, but not if using kahan summation as the bottom
# bits must be restored first. (This update occurs in copy_kahan_() instead)
if not self.optimizer.use_kahan_summation:
p_data_fp32.add_(-update)
# if p.dtype in {torch.float16, torch.bfloat16}:
# p.copy_(p_data_fp32)
if p.dtype == torch.bfloat16:
copy_stochastic_(p, p_data_fp32)
if self.optimizer.use_kahan_summation:
copy_kahan_(p, p_data_fp32, state, update)
else:
copy_stochastic_(p, p_data_fp32)
elif p.dtype == torch.float16:
p.copy_(p_data_fp32)