This commit is contained in:
araleza
2025-09-26 01:40:27 +08:00
committed by GitHub
4 changed files with 315 additions and 7 deletions

View File

@@ -330,6 +330,8 @@ def train(args):
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
fused_optimizers_supported = ['adafactor', 'adamoffload', 'nadamoffload', 'adamwoffload', 'nadamwoffload']
if args.blockwise_fused_optimizers:
# fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
# Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
@@ -381,10 +383,25 @@ def train(args):
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
optimizer_train_fn = lambda: None # dummy function
optimizer_eval_fn = lambda: None # dummy function
if (args.optimizer_type in fused_optimizers_supported) and args.full_bf16:
logger.warning("Use of --blockwise_fused_optimizers is preventing stochastic/Kahan weight updates.")
else:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
# Pass any Kahan summation arg to the optimizer
if args.kahan_summation:
# Self check parameter compatibility
if args.optimizer_type.lower() not in fused_optimizers_supported:
logger.warning("Kahan summation has been requested, but this is not supported by the selected optimizer.")
if not args.full_bf16:
logger.warning("Kahan summation requires --full_bf16")
if args.blockwise_fused_optimizers:
logger.warning("Kahan summation has been requested, but these are not compatible with --blockwise_fused_optimizer. "\
"Perhaps try --fused_backward_pass instead.")
optimizer.use_kahan_summation = args.kahan_summation
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
@@ -474,10 +491,18 @@ def train(args):
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
# use fused optimizer for backward pass. Only some specific optimizers are supported.
import library.adafactor_fused
import library.adamw_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
if args.optimizer_type.lower() == "adafactor":
library.adafactor_fused.patch_adafactor_fused(optimizer)
elif args.optimizer_type.lower() == "adamoffload" or args.optimizer_type.lower() == "adamwoffload":
library.adamw_fused.patch_adamw_offload_fused(optimizer, False)
elif args.optimizer_type.lower() == "nadamoffload" or args.optimizer_type.lower() == "nadamwoffload":
library.adamw_fused.patch_adamw_offload_fused(optimizer, True) # Nesterov
else:
logger.error(f"Optimizer '{args.optimizer_type}' does not have a --fused_backward_pass implementation available")
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
for parameter, param_name in zip(param_group["params"], param_name_group):
@@ -816,6 +841,12 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
)
parser.add_argument(
"--kahan_summation",
action="store_true",
help="Offloads to CPU the float part lost during bf16 quantization, and re-adds it to the next step / "\
"bf16 量子化中に失われた浮動小数点部分を CPU にオフロードし、次のステップに再度追加します",
)
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",

View File

@@ -28,6 +28,62 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
del result
# Kahan summation for bfloat16
# The implementation was provided by araleza.
# Based on paper "Revisiting BFloat16 Training": https://arxiv.org/pdf/2010.06192
def copy_kahan_(target: torch.Tensor, source: torch.Tensor, state: dict, update: torch.Tensor):
    """
    Copies source into target using Kahan summation.

    The lower 16 bits of the float32 weight that are lost on conversion to
    bfloat16 are sent to the CPU until the next step, where they are re-added
    onto the weights before adding the gradient update. This produces near
    float32-like weight behavior, although the copies back and forth to main
    memory result in slower training steps.

    Note: `source` is modified in place (residual restore, gradient update and
    round-to-nearest masking all happen on it via an int32 bit view).

    Args:
        target: the destination tensor with dtype=bfloat16
        source: the source tensor with dtype=float32 (modified in place)
        state: the optimizer state dict for this parameter; the
            'kahan_residuals' entry is read and rewritten here
        update: the change in weights due to the gradient; it is applied here
            (as `-update`) instead of by the caller so the residual bits can
            be restored first
    """
    # Initialize residuals to 0 for first step.
    # NOTE(review): initialized as int16 but stored back as uint16 below —
    # harmless for zeros since the next line widens to int32, but presumably
    # uint16 was intended; confirm.
    if state.get('kahan_residuals') is None:
        state['kahan_residuals'] = torch.zeros_like(source, dtype=torch.int16)

    # Need this in 32 bit as PyTorch doesn't support mixed 32-bit and 16-bit math operations
    state['kahan_residuals'] = state['kahan_residuals'].to(source.device).to(dtype=torch.int32)

    # Bring the previous step's lower bits of the weights back from the
    # cpu device, and add them back to the weights of the current step.
    source_i32 = source.view(dtype=torch.int32)  # Can't do math on uint32
    source_i32.add_(state['kahan_residuals'])

    # Reverse any rounding up during the cast to bf16 on the previous step
    # (a residual >= 0x8000 means the +32768 offset below carried into the
    # upper 16 bits last step).
    rounded_up = state['kahan_residuals'] >= 32768
    source_i32[rounded_up] -= 65536

    # Must add the gradient update after the bottom bits are restored in case
    # the exponent is changed by the update, or the -65536 on the line above
    # would drop the uint32 value below zero, which is invalid.
    source.add_(-update)

    # Get the lower bits into the residual
    torch.bitwise_and(source_i32, 0x0000FFFF, out=state['kahan_residuals'])

    # Ensure rounding to bfloat16 matches expectations. These lines may not be
    # necessary as target.copy_ should do this rounding anyway.
    source_i32.add_(32768)  # Add offset so clipping bits performs round-to-nearest
    source_i32.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32. Leaves only upper bits in source

    # Move the 16-bit Kahan bits from VRAM to main memory.
    # NOTE(review): torch.uint16 has only limited support in older PyTorch
    # releases — confirm the project's minimum torch version supports it.
    state['kahan_residuals'] = state['kahan_residuals'].to(dtype=torch.uint16).to("cpu")

    # Copy the quantized floats into the target tensor
    target.copy_(source)
@torch.no_grad()
def adafactor_step_param(self, p, group):
if p.grad is None:
@@ -102,13 +158,19 @@ def adafactor_step_param(self, p, group):
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
p_data_fp32.add_(-update)
# Add on gradient update, but not if using kahan summation as the bottom
# bits must be restored first. (This update occurs in copy_kahan_() instead)
if not self.optimizer.use_kahan_summation:
p_data_fp32.add_(-update)
# if p.dtype in {torch.float16, torch.bfloat16}:
# p.copy_(p_data_fp32)
if p.dtype == torch.bfloat16:
copy_stochastic_(p, p_data_fp32)
if self.optimizer.use_kahan_summation:
copy_kahan_(p, p_data_fp32, state, update)
else:
copy_stochastic_(p, p_data_fp32)
elif p.dtype == torch.float16:
p.copy_(p_data_fp32)

206
library/adamw_fused.py Normal file
View File

@@ -0,0 +1,206 @@
import math
import torch
from library.adafactor_fused import copy_stochastic_
from library.adafactor_fused import copy_kahan_
@torch.no_grad()
def adamw_offload_step_param(self, p, group):
    """
    Fused (N)AdamW step for a single parameter, with optimizer state offloaded
    to CPU between steps.

    Large tensors (> 4096 elements) have their exp_avg / exp_avg_sq moments
    compressed to uint16 (via a pow(1/8) companding curve plus min/max scaling)
    before being moved to CPU; small tensors keep full float32 moments on CPU.
    The weight update is applied here (or deferred to copy_kahan_ when Kahan
    summation is enabled) rather than by a conventional optimizer.step().

    NOTE(review): this function reads `self.state` and `self.use_nesterov` but
    `self.optimizer.use_kahan_summation` — so `self` is presumably an
    accelerate-style wrapper exposing the inner optimizer as `.optimizer` while
    forwarding `state`; confirm against the caller, as a bare patched optimizer
    would lack `.optimizer`.

    Args:
        self: the (patched) optimizer instance; provides state, use_nesterov,
            and optimizer.use_kahan_summation.
        p: the parameter tensor to update (float32/float16/bfloat16).
        group: the optimizer param_group holding lr, betas, eps, weight_decay.
    """
    if p.grad is None:
        return

    grad = p.grad
    if grad.dtype in {torch.float16, torch.bfloat16}:
        grad = grad.float()
    if grad.is_sparse:
        raise RuntimeError("This (N)AdamW implementation does not support sparse gradients.")

    state = self.state[p]
    grad_shape = grad.shape  # NOTE(review): appears unused in this function

    # Work on a float32 copy of low-precision weights; quantized back at the end.
    p_data_fp32 = p
    if p.dtype in {torch.float16, torch.bfloat16}:
        p_data_fp32 = p_data_fp32.float()

    # Tensors with few elements may be more sensitive to quantization
    # errors, so keep them in float32
    high_quality = torch.numel(p) <= 4096

    # State Initialization
    if len(state) == 0:
        state["step"] = 0
        if high_quality:
            # Exponential averages stored in f32 format
            state['exp_avg'] = torch.zeros_like(p, dtype=torch.float32)
            state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float32)
        else:
            # Exponential averages stored in u16 format
            state['exp_avg'] = torch.zeros_like(p, dtype=torch.uint16)
            state['exp_avg_min'] = 0.0
            state['exp_avg_max'] = 1.0
            state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.uint16)
            state['exp_avg_sq_min'] = 0.0
            state['exp_avg_sq_max'] = 1.0

    state["step"] += 1

    # NAdam / AdamW hyperparameters
    beta1, beta2 = group['betas']
    eps = group['eps']  # 1e-8
    weight_decay = group.get('weight_decay', 0.0)

    # Bias correction terms
    bias_correction1 = 1.0 - math.pow(beta1, state['step'])
    bias_correction2 = 1.0 - math.pow(beta2, state['step'])
    eps_p2: float = math.pow(eps, 2)  # NOTE(review): appears unused in this function

    # Bring state back from CPU
    if high_quality:
        # These exponential averages are already in float32 format
        state['exp_avg'] = state['exp_avg'].to(p.device)
        state['exp_avg_sq'] = state['exp_avg_sq'].to(p.device)
    else:
        # Unpack these exponential averages from uint16 format
        # A power function was applied to the tensor values, as they are usually
        # distributed in an exponential fashion. After the power function was applied,
        # the min and max of the results were noted, and then the values were scaled
        # to the 0-65535 range for storage. This process is reversed here.
        u16power = 8.0  # This value worked acceptably in testing to spread the values more evenly
        exp_avg_min = state['exp_avg_min']
        exp_avg_max = state['exp_avg_max']
        exp_avg_sq_min = state['exp_avg_sq_min']
        exp_avg_sq_max = state['exp_avg_sq_max']
        uint16_recreate_a = state['exp_avg'].to(p.device).to(dtype=torch.float32) / 65535.0 * (exp_avg_max - exp_avg_min) + exp_avg_min
        state['exp_avg'] = torch.pow(torch.abs(uint16_recreate_a), u16power) * torch.sgn(uint16_recreate_a)
        del uint16_recreate_a
        uint16_recreate_a_sq = state['exp_avg_sq'].to(p.device).to(dtype=torch.float32) / 65535.0 * (exp_avg_sq_max - exp_avg_sq_min) + exp_avg_sq_min
        state['exp_avg_sq'] = torch.pow(torch.abs(uint16_recreate_a_sq), u16power) * torch.sgn(uint16_recreate_a_sq)
        del uint16_recreate_a_sq

    exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

    # Update biased first and second moment estimates (in place)
    exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

    # Compute bias-corrected second moment for denominator
    exp_avg_sq_corrected = exp_avg_sq / bias_correction2

    # Compute update based on whether Nesterov momentum (NAdam) is being used
    if self.use_nesterov:
        # The next step's bias correction for momentum is needed
        bias_correction1_next = 1.0 - math.pow(beta1, state['step'] + 1)
        # NAdam update: combines current gradient with momentum look-ahead
        momentum_cache = exp_avg / bias_correction1_next
        update = (beta1 * momentum_cache + (1.0 - beta1) * grad / bias_correction1) / (exp_avg_sq_corrected.sqrt() + eps)
    else:
        # Standard Adam update: use bias-corrected first moment directly
        exp_avg_corrected = exp_avg / bias_correction1
        update = exp_avg_corrected / (exp_avg_sq_corrected.sqrt() + eps)

    lr: float = group['lr']

    # Apply learning rate
    update.mul_(lr)

    # Apply weight decay (decoupled, AdamW-style: scales the weights directly)
    if weight_decay != 0:
        p_data_fp32.mul_(1 - lr * weight_decay)

    # Send state back to CPU until the next step for this parameter
    if high_quality:
        # These are kept in f32 format between steps
        state['exp_avg'] = state['exp_avg'].to('cpu')
        state['exp_avg_sq'] = state['exp_avg_sq'].to('cpu')
    else:
        # Compress the exp_avg and exp_avg_sq tensors to cut their size down
        # from 32 bit to 16 bit.
        #
        # A power function is applied to try to linearize the tensor values, as
        # they are usually distributed in an exponential fashion. It would have
        # been preferable to use a log() function, but the input values can be
        # negative, so a pow() function is used instead. The 1/16th power was
        # chosen fairly arbitrarily, but seemed to distribute the values fairly
        # reasonably in some simple tests.
        #
        # After the power function is applied, the min and max of the resulting
        # values are stored, and the values are then scaled to the 0-65535 range
        # for storage.
        #
        # Doing this instead of storing these values as bf16 reduced the L1
        # error between the stored values and the true f32 values by around 90%,
        # with a notable increase in output image quality.
        #
        # NOTE(review): if max == min (e.g. a constant tensor) the division
        # below produces NaN/inf — confirm this cannot occur in practice.
        log_exp_avg = torch.pow(torch.abs(state['exp_avg']), 1.0 / u16power) * torch.sgn(state['exp_avg'])
        exp_avg_min = torch.min(log_exp_avg)
        exp_avg_max = torch.max(log_exp_avg)
        state['exp_avg_min'] = exp_avg_min
        state['exp_avg_max'] = exp_avg_max
        normalized = (log_exp_avg - exp_avg_min) / (exp_avg_max - exp_avg_min)
        del log_exp_avg
        state['exp_avg'] = (normalized * 65535.0).clamp(0, 65535).to(dtype=torch.uint16).to('cpu')
        log_exp_avg_sq = torch.pow(torch.abs(state['exp_avg_sq']), 1.0 / u16power) * torch.sgn(state['exp_avg_sq'])
        exp_avg_sq_min = torch.min(log_exp_avg_sq)
        exp_avg_sq_max = torch.max(log_exp_avg_sq)
        state['exp_avg_sq_min'] = exp_avg_sq_min
        state['exp_avg_sq_max'] = exp_avg_sq_max
        normalized_sq = (log_exp_avg_sq - exp_avg_sq_min) / (exp_avg_sq_max - exp_avg_sq_min)
        del log_exp_avg_sq
        state['exp_avg_sq'] = (normalized_sq * 65535.0).clamp(0, 65535).to(dtype=torch.uint16).to('cpu')

    # Add on gradient update, but not if using kahan summation as the bottom
    # bits must be restored first. (This update occurs in copy_kahan_() instead)
    if not self.optimizer.use_kahan_summation:
        p_data_fp32.add_(-update)

    # Write the float32 working copy back to the low-precision parameter
    if p.dtype == torch.bfloat16:
        if self.optimizer.use_kahan_summation:
            copy_kahan_(p, p_data_fp32, state, update)
        else:
            copy_stochastic_(p, p_data_fp32)
    elif p.dtype == torch.float16:
        p.copy_(p_data_fp32)
@torch.no_grad()
def adamw_offload_step(self, closure=None):
    """
    Perform a single optimization step over every parameter group.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.

    Returns:
        The loss returned by ``closure``, or ``None`` when no closure is given.
    """
    loss = closure() if closure is not None else None

    # Delegate the per-parameter work to the fused step implementation.
    for param_group in self.param_groups:
        for param in param_group["params"]:
            adamw_offload_step_param(self, param, param_group)

    return loss
def patch_adamw_offload_fused(optimizer, use_nesterov):
    """
    Monkey-patch ``optimizer`` in place so its steps run through the fused
    (N)AdamW offload implementation in this module.

    Args:
        optimizer: the torch optimizer instance to patch.
        use_nesterov: when True, the patched step applies Nesterov momentum
            (NAdam); when False, plain AdamW momentum is used.
    """
    # Flag read by adamw_offload_step_param via ``self``.
    optimizer.use_nesterov = use_nesterov
    # Bind the module-level functions to this instance so they behave as
    # methods (``self`` becomes the optimizer) via the descriptor protocol.
    for name, func in (("step_param", adamw_offload_step_param),
                       ("step", adamw_offload_step)):
        setattr(optimizer, name, func.__get__(optimizer))

View File

@@ -4884,9 +4884,6 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
optimizer_type = optimizer_type.lower()
if args.fused_backward_pass:
assert (
optimizer_type == "Adafactor".lower()
), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します"
assert (
args.gradient_accumulation_steps == 1
), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません"
@@ -5130,6 +5127,18 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
optimizer_class = transformers.optimization.Adafactor
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type.lower() == "adamoffload" or optimizer_type.lower() == "nadamoffload":
logger.info(f"use [N]AdamOffload optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.Adam
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type.lower() == "adamwoffload" or optimizer_type.lower() == "nadamwoffload":
logger.info(f"use [N]AdamWOffload optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.AdamW # default weight_decay seems to be 0.01
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "AdamW".lower():
logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.AdamW