mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Added support for Kahan summation for Adafactor-optimized Flux FFT
This commit is contained in:
@@ -381,10 +381,25 @@ def train(args):
|
||||
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
|
||||
optimizer_train_fn = lambda: None # dummy function
|
||||
optimizer_eval_fn = lambda: None # dummy function
|
||||
|
||||
if args.optimizer_type == "adafactor" and args.full_bf16:
|
||||
logger.warning("Use of --blockwise_fused_optimizers with Adafactor optimizer prevents stochastic/kahan weight updates.")
|
||||
else:
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
||||
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
||||
|
||||
# Pass any Kahan summation arg to the optimizer
|
||||
if args.kahan_summation:
|
||||
# Self check parameter compatibility
|
||||
if args.optimizer_type != "adafactor":
|
||||
logger.warning("Kahan summation has been requested, but currently this is only supported by the supplied Adafactor optimizer.")
|
||||
if not args.full_bf16:
|
||||
logger.warning("Kahan summation requires --full_bf16")
|
||||
if args.blockwise_fused_optimizers:
|
||||
logger.warning("Kahan summation has been requested, but it is incompatible with --blockwise_fused_optimizer. "\
|
||||
"Perhaps try --fused_backward_pass instead.")
|
||||
optimizer.use_kahan_summation = args.kahan_summation
|
||||
|
||||
# prepare dataloader
|
||||
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||
# some strategies can be None
|
||||
@@ -815,6 +830,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kahan-summation",
|
||||
action="store_true",
|
||||
help="Offloads to CPU the float parts lost during bf16 quantization, and re-adds them to the next step / "\
|
||||
"bf16 量子化中に失われた浮動小数点部分を CPU にオフロードし、次のステップに再度追加します",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_latents_validity_check",
|
||||
action="store_true",
|
||||
@@ -838,13 +859,3 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Build the CLI argument parser with all training options.
    parser = setup_parser()

    args = parser.parse_args()
    # Validate the command-line arguments, then merge in values from a
    # config file (the config can override/extend the parsed args).
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    # Run the training loop with the fully-resolved arguments.
    train(args)
|
||||
|
||||
@@ -28,6 +28,58 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
|
||||
del result
|
||||
|
||||
|
||||
# Kahan summation for bfloat16
# The implementation was provided by araleza.
# Based on the paper "Revisiting BFloat16 Training": https://arxiv.org/pdf/2010.06192
#
# Module-level state shared across calls to copy_kahan_:
#   kahan_residuals - per-tensor float32 residuals, parked on the CPU between steps
#   tensor_index    - position of the current tensor within the step's call sequence
#   prev_step       - step number seen on the previous call; a change marks a new step

kahan_residuals = []
tensor_index = 0
prev_step = 0


def copy_kahan_(target: torch.Tensor, source: torch.Tensor, step):
    """
    Copies source into target using Kahan summations.

    The part of the float32 weight that is lost on conversion to bfloat16 is sent
    to the CPU until the next step, where it is re-added onto that step's updated
    weight. This produces near float32-like weight behavior, although the copies
    back and forth to main memory result in slower training steps.

    NOTE: relies on tensors passing through this function in the same order on
    every step, so that each tensor is paired with its own residual.

    Args:
        target: the target tensor with dtype=bfloat16
        source: the source tensor with dtype=float32
        step: the optimizer's current step number; a change signals a new step
    """
    global kahan_residuals, tensor_index, prev_step

    # Calculate the group index of the current residual Tensor. Tensors
    # pass through this copy function in the same order at each step.
    tensor_index += 1
    if prev_step != step:  # Starting new step?
        prev_step = step
        tensor_index = 0

    # Initialize this tensor's residual to 0.0 on the first step
    if len(kahan_residuals) <= tensor_index:
        kahan_residuals.append(torch.zeros_like(source))

    # Bring the residual from the previous step back from the cpu device, and add it to the
    # float32 weights of the current step
    summed = kahan_residuals[tensor_index].detach().to(source.device)  # Residual is float32 type
    summed.add_(source)

    # Mask off the lower 16 bits of the mantissa, adding 32768 in order to
    # round-to-nearest when the lower bits are clipped off
    summed_i32 = summed.view(dtype=torch.int32).detach().clone()
    summed_quantized_i32 = summed_i32.add_(32768).bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32
    summed_quantized = summed_quantized_i32.view(dtype=torch.float32)

    # The next residual is the difference between the quantized and unquantized weights;
    # sub() is out-of-place, so a fresh tensor (not an alias of summed) is stored.
    kahan_residuals[tensor_index] = summed.sub(summed_quantized).detach().to("cpu")

    # Copy the quantized floats into the target tensor (float32 -> bfloat16 here is
    # exact, because the low mantissa bits were already zeroed above)
    target.copy_(summed_quantized)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def adafactor_step_param(self, p, group):
|
||||
if p.grad is None:
|
||||
@@ -108,7 +160,10 @@ def adafactor_step_param(self, p, group):
|
||||
# p.copy_(p_data_fp32)
|
||||
|
||||
if p.dtype == torch.bfloat16:
|
||||
copy_stochastic_(p, p_data_fp32)
|
||||
if self.optimizer.use_kahan_summation:
|
||||
copy_kahan_(p, p_data_fp32, state["step"])
|
||||
else:
|
||||
copy_stochastic_(p, p_data_fp32)
|
||||
elif p.dtype == torch.float16:
|
||||
p.copy_(p_data_fp32)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user