Support for fused (N)AdamW + Kahan + momentum offloading FFT on a 5090.

araleza
2025-08-24 16:00:38 +01:00
parent 4b12746d39
commit 225ea36285
4 changed files with 231 additions and 7 deletions
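For orientation, a hypothetical invocation that would exercise the new code path might look like the sketch below; the script name (flux_train.py), the accelerate launcher, and the omitted model/dataset arguments are assumptions, while the optimizer-related flags come from this commit:

    accelerate launch flux_train.py \
        --optimizer_type nadamwoffload \
        --fused_backward_pass \
        --full_bf16 \
        --kahan_summation \
        <model, dataset and other training arguments>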


@@ -330,6 +330,8 @@ def train(args):
    # prepare the classes needed for training
    accelerator.print("prepare optimizer, data loader etc.")

    fused_optimizers_supported = ['adafactor', 'adamoffload', 'nadamoffload', 'adamwoffload', 'nadamwoffload']

    if args.blockwise_fused_optimizers:
        # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
        # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
@@ -381,10 +383,25 @@ def train(args):
            raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
        optimizer_train_fn = lambda: None  # dummy function
        optimizer_eval_fn = lambda: None  # dummy function

        if (args.optimizer_type.lower() not in fused_optimizers_supported) and args.full_bf16:
            logger.warning("Use of --blockwise_fused_optimizers with Adafactor optimizer prevents stochastic/Kahan weight updates.")
    else:
        _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
        optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)

    # Pass any Kahan summation arg to the optimizer
    if args.kahan_summation:
        # Self-check parameter compatibility
        if args.optimizer_type.lower() not in fused_optimizers_supported:
            logger.warning("Kahan summation has been requested, but it is not supported by the selected optimizer.")
        if not args.full_bf16:
            logger.warning("Kahan summation requires --full_bf16")
        if args.blockwise_fused_optimizers:
            logger.warning("Kahan summation has been requested, but it is not compatible with --blockwise_fused_optimizers. "
                           "Perhaps try --fused_backward_pass instead.")
    optimizer.use_kahan_summation = args.kahan_summation

    # prepare dataloader
    # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
    # some strategies can be None
@@ -474,10 +491,18 @@ def train(args):
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

    if args.fused_backward_pass:
        # use fused optimizer for backward pass: other optimizers will be supported in the future
        # use fused optimizer for backward pass. Only some specific optimizers are supported.
        import library.adafactor_fused
        import library.adamw_fused

        library.adafactor_fused.patch_adafactor_fused(optimizer)
        if args.optimizer_type.lower() == "adafactor":
            library.adafactor_fused.patch_adafactor_fused(optimizer)
        elif args.optimizer_type.lower() == "adamoffload" or args.optimizer_type.lower() == "adamwoffload":
            library.adamw_fused.patch_adamw_offload_fused(optimizer, False)
        elif args.optimizer_type.lower() == "nadamoffload" or args.optimizer_type.lower() == "nadamwoffload":
            library.adamw_fused.patch_adamw_offload_fused(optimizer, True)  # Nesterov
        else:
            logger.error(f"Optimizer '{args.optimizer_type}' does not have a --fused_backward_pass implementation available")

        for param_group, param_name_group in zip(optimizer.param_groups, param_names):
            for parameter, param_name in zip(param_group["params"], param_name_group):
@@ -816,6 +841,12 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
)
parser.add_argument(
"--kahan_summation",
action="store_true",
help="Offloads to CPU the float part lost during bf16 quantization, and re-adds it to the next step / "\
"bf16 量子化中に失われた浮動小数点部分を CPU にオフロードし、次のステップに再度追加します",
)
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",


@@ -28,6 +28,62 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
    del result


# Kahan summation for bfloat16
# The implementation was provided by araleza.
# Based on the paper "Revisiting BFloat16 Training": https://arxiv.org/pdf/2010.06192
def copy_kahan_(target: torch.Tensor, source: torch.Tensor, state, update):
    """
    Copies source into target using Kahan summation.

    The lower bits of the float32 weight that are lost on conversion to bfloat16
    are sent to the CPU until the next step, where they are re-added onto the weights
    before adding the gradient update. This produces near float32-like weight behavior,
    although the copies back and forth to main memory result in slower training steps.

    Args:
        target: the target tensor with dtype=bfloat16
        source: the source tensor with dtype=float32
        state: the optimizer state, used to store the Kahan residuals
        update: the change in weights due to the gradient
    """
    # Initialize residuals to 0 for the first step
    if state.get('kahan_residuals') is None:
        state['kahan_residuals'] = torch.zeros_like(source, dtype=torch.int16)

    # Need this in 32 bit as PyTorch doesn't support mixed 32-bit and 16-bit math operations
    state['kahan_residuals'] = state['kahan_residuals'].to(source.device).to(dtype=torch.int32)

    # Bring the previous step's lower bits of the weights back from the
    # CPU device, and add them back to the weights of the current step.
    source_i32 = source.view(dtype=torch.int32)  # Can't do math on uint32
    source_i32.add_(state['kahan_residuals'])

    # Reverse any rounding up during the cast to bf16 on the previous step
    rounded_up = state['kahan_residuals'] >= 32768
    source_i32[rounded_up] -= 65536

    # Must add the gradient update after the bottom bits are restored, in case
    # the exponent is changed by the update, or the -65536 on the line above
    # would drop the uint32 value below zero, which is invalid.
    source.add_(-update)

    # Get the lower bits into the residual
    torch.bitwise_and(source_i32, 0x0000FFFF, out=state['kahan_residuals'])

    # Ensure rounding to bfloat16 matches expectations. These lines may not be
    # necessary as target.copy_ should do this rounding anyway.
    source_i32.add_(32768)           # Add offset so clipping the low bits performs round-to-nearest
    source_i32.bitwise_and_(-65536)  # -65536 = 0xFFFF0000 as a signed int32; leaves only the upper bits in source

    # Move the 16-bit Kahan bits from VRAM to main memory
    state['kahan_residuals'] = state['kahan_residuals'].to(dtype=torch.uint16).to("cpu")

    # Copy the quantized floats into the target tensor
    target.copy_(source)


@torch.no_grad()
def adafactor_step_param(self, p, group):
    if p.grad is None:
@@ -102,13 +158,19 @@ def adafactor_step_param(self, p, group):
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
p_data_fp32.add_(-update)
# Add on gradient update, but not if using kahan summation as the bottom
# bits must be restored first. (This update occurs in copy_kahan_() instead)
if not self.optimizer.use_kahan_summation:
p_data_fp32.add_(-update)
# if p.dtype in {torch.float16, torch.bfloat16}:
# p.copy_(p_data_fp32)
if p.dtype == torch.bfloat16:
copy_stochastic_(p, p_data_fp32)
if self.optimizer.use_kahan_summation:
copy_kahan_(p, p_data_fp32, state, update)
else:
copy_stochastic_(p, p_data_fp32)
elif p.dtype == torch.float16:
p.copy_(p_data_fp32)
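As an aside, the bit manipulation above is easier to follow with a standalone round-trip example. The sketch below is not part of the commit; it only assumes the same layout copy_kahan_ relies on (a bfloat16 value is the top 16 bits of the corresponding float32 bit pattern), and it ignores the round-to-nearest correction that copy_kahan_ applies via rounded_up:

    import torch

    w = torch.tensor([1.0001], dtype=torch.float32)   # full-precision weight
    w_i32 = w.view(dtype=torch.int32)                 # reinterpret the fp32 bits
    residual = torch.bitwise_and(w_i32, 0x0000FFFF)   # low 16 bits that a bf16 cast discards
    w_bf16 = w.to(torch.bfloat16)                     # quantized weight (top 16 bits, rounded)

    # Next step: re-attach the stored residual to the quantized weight
    restored = w_bf16.float()                         # exact fp32 value of the bf16 weight
    restored.view(dtype=torch.int32).add_(residual)   # put the low bits back in place
    print(w.item(), restored.item())                  # identical values: no bits were lost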

library/adamw_fused.py (new file, 122 lines)

@@ -0,0 +1,122 @@
import math

import torch

from library.adafactor_fused import copy_stochastic_
from library.adafactor_fused import copy_kahan_


@torch.no_grad()
def adamw_offload_step_param(self, p, group):
    if p.grad is None:
        return

    grad = p.grad
    if grad.dtype in {torch.float16, torch.bfloat16}:
        grad = grad.float()
    if grad.is_sparse:
        raise RuntimeError("This (N)AdamW implementation does not support sparse gradients.")

    state = self.state[p]
    grad_shape = grad.shape

    p_data_fp32 = p
    if p.dtype in {torch.float16, torch.bfloat16}:
        p_data_fp32 = p_data_fp32.float()

    # State Initialization
    if len(state) == 0:
        state["step"] = 0
        state['exp_avg'] = torch.zeros_like(p, dtype=torch.bfloat16)
        state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.bfloat16)

    state["step"] += 1

    # NAdam
    beta1, beta2 = group['betas']
    eps = group['eps']  # 1e-8
    weight_decay = group.get('weight_decay', 0.0)

    # Bias correction terms
    bias_correction1 = 1.0 - math.pow(beta1, state['step'])
    bias_correction2 = 1.0 - math.pow(beta2, state['step'])

    eps_p2: float = math.pow(eps, 2)

    # Bring state back from CPU
    state['exp_avg']    = state['exp_avg']   .to('cuda').to(dtype=torch.float32)
    state['exp_avg_sq'] = state['exp_avg_sq'].to('cuda').to(dtype=torch.float32)
    exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

    # Update biased first and second moment estimates
    exp_avg   .mul_(beta1).add_    (grad, alpha=1.0 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

    # Compute bias-corrected second moment for denominator
    exp_avg_sq_corrected = exp_avg_sq / bias_correction2

    # Compute update based on whether Nesterov momentum (NAdam) is being used
    if self.use_nesterov:
        # The next step's bias correction for momentum is needed
        bias_correction1_next = 1.0 - math.pow(beta1, state['step'] + 1)

        # NAdam update: combines current gradient with momentum look-ahead
        momentum_cache = exp_avg / bias_correction1_next
        update = (beta1 * momentum_cache + (1.0 - beta1) * grad / bias_correction1) / (exp_avg_sq_corrected.sqrt() + eps)
    else:
        # Standard Adam update: use bias-corrected first moment directly
        exp_avg_corrected = exp_avg / bias_correction1
        update = exp_avg_corrected / (exp_avg_sq_corrected.sqrt() + eps)

    lr: float = group['lr']

    # Apply learning rate
    update.mul_(lr)

    # Apply weight decay
    if weight_decay != 0:
        p_data_fp32.mul_(1 - lr * weight_decay)

    # Keep state on CPU
    state['exp_avg']    = state['exp_avg']   .to(dtype=torch.bfloat16).to('cpu')
    state['exp_avg_sq'] = state['exp_avg_sq'].to(dtype=torch.bfloat16).to('cpu')

    # Add on gradient update, but not if using kahan summation as the bottom
    # bits must be restored first. (This update occurs in copy_kahan_() instead)
    if not self.optimizer.use_kahan_summation:
        p_data_fp32.add_(-update)

    if p.dtype == torch.bfloat16:
        if self.optimizer.use_kahan_summation:
            copy_kahan_(p, p_data_fp32, state, update)
        else:
            copy_stochastic_(p, p_data_fp32)
    elif p.dtype == torch.float16:
        p.copy_(p_data_fp32)


@torch.no_grad()
def adamw_offload_step(self, closure=None):
    """
    Performs a single optimization step

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        for p in group["params"]:
            adamw_offload_step_param(self, p, group)

    return loss


def patch_adamw_offload_fused(optimizer, use_nesterov):
    optimizer.use_nesterov = use_nesterov
    optimizer.step_param = adamw_offload_step_param.__get__(optimizer)
    optimizer.step = adamw_offload_step.__get__(optimizer)
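For context on how this module is meant to be driven: with --fused_backward_pass the optimizer step happens per parameter inside a post-accumulate gradient hook rather than in a single optimizer.step() call (see the training-script hunk above and the PyTorch tutorial it links). A minimal wiring sketch, with the hook body simplified relative to what the training script actually does:

    import torch
    import library.adamw_fused

    # `optimizer` is assumed to be the already-prepared optimizer returned by get_optimizer()
    library.adamw_fused.patch_adamw_offload_fused(optimizer, True)  # True -> Nesterov (NAdamW)

    for param_group in optimizer.param_groups:
        for parameter in param_group["params"]:
            if parameter.requires_grad:

                def make_hook(p_group):
                    def hook(tensor: torch.Tensor):
                        optimizer.step_param(tensor, p_group)  # fused step for just this parameter
                        tensor.grad = None                     # free the gradient right away
                    return hook

                parameter.register_post_accumulate_grad_hook(make_hook(param_group))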


@@ -4813,9 +4813,6 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
    optimizer_type = optimizer_type.lower()

    if args.fused_backward_pass:
        assert (
            optimizer_type == "Adafactor".lower()
        ), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します"

        assert (
            args.gradient_accumulation_steps == 1
        ), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません"
@@ -5059,6 +5056,18 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
        optimizer_class = transformers.optimization.Adafactor
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type.lower() == "adamoffload" or optimizer_type.lower() == "nadamoffload":
        logger.info(f"use [N]AdamOffload optimizer | {optimizer_kwargs}")
        optimizer_class = torch.optim.Adam
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type.lower() == "adamwoffload" or optimizer_type.lower() == "nadamwoffload":
        logger.info(f"use [N]AdamWOffload optimizer | {optimizer_kwargs}")
        optimizer_class = torch.optim.AdamW  # default weight_decay seems to be 0.01
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type == "AdamW".lower():
        logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
        optimizer_class = torch.optim.AdamW
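One design note on the hunk above: the new adamoffload / nadamoffload / adamwoffload / nadamwoffload optimizer types construct ordinary torch.optim.Adam / torch.optim.AdamW instances at this point. The CPU offloading of the momentum state, the Nesterov variant, and the Kahan/stochastic bf16 copies only take effect once --fused_backward_pass replaces step / step_param via library.adamw_fused.patch_adamw_offload_fused().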