Support for fused (N)AdamW + Kahan + momentum offloading FFT on a 5090.

araleza committed 2025-08-24 16:00:38 +01:00
parent 4b12746d39
commit 225ea36285
4 changed files with 231 additions and 7 deletions


@@ -4813,9 +4813,6 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
     optimizer_type = optimizer_type.lower()
     if args.fused_backward_pass:
-        assert (
-            optimizer_type == "Adafactor".lower()
-        ), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します"
         assert (
             args.gradient_accumulation_steps == 1
         ), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません"
@@ -5059,6 +5056,18 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
         optimizer_class = transformers.optimization.Adafactor
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+    elif optimizer_type.lower() == "adamoffload" or optimizer_type.lower() == "nadamoffload":
+        logger.info(f"use [N]AdamOffload optimizer | {optimizer_kwargs}")
+        optimizer_class = torch.optim.Adam
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+    elif optimizer_type.lower() == "adamwoffload" or optimizer_type.lower() == "nadamwoffload":
+        logger.info(f"use [N]AdamWOffload optimizer | {optimizer_kwargs}")
+        optimizer_class = torch.optim.AdamW  # default weight_decay seems to be 0.01
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
     elif optimizer_type == "AdamW".lower():
         logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
         optimizer_class = torch.optim.AdamW
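
The Kahan summation and momentum offloading named in the commit title live outside this excerpt; the branches above only select torch.optim.Adam / torch.optim.AdamW as the base optimizer. As a rough sketch of the two techniques together (the function, the state layout, and the dtypes are illustrative assumptions, not this commit's code): the two moment buffers stay in fp32 on the CPU, and a Kahan compensation buffer carries the rounding error that a low-precision parameter update would otherwise discard.

import torch

@torch.no_grad()
def adamw_step_kahan(p, state, lr=1e-4, betas=(0.9, 0.999),
                     eps=1e-8, weight_decay=0.01):
    # One decoupled-weight-decay Adam (AdamW) update for one parameter.
    beta1, beta2 = betas
    state["step"] += 1

    # "Momentum offloading": the moments are kept and updated on the CPU.
    grad32 = p.grad.detach().to("cpu", dtype=torch.float32)
    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
    exp_avg.mul_(beta1).add_(grad32, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad32, grad32, value=1 - beta2)

    bias1 = 1 - beta1 ** state["step"]
    bias2 = 1 - beta2 ** state["step"]
    update = exp_avg.div(bias1).div_((exp_avg_sq / bias2).sqrt_().add_(eps))
    update = update.mul_(-lr).to(p.device)

    p.mul_(1 - lr * weight_decay)  # decoupled weight decay

    # Kahan summation: fold in last step's rounding error, apply the
    # update, then record the error this step's rounding leaves behind.
    update.add_(state["kahan_comp"])
    prev = p.detach().clone()
    p.add_(update)
    state["kahan_comp"].copy_(update - (p.float() - prev.float()))

state would be initialized once per parameter, e.g. {"step": 0} plus fp32 zeros_like buffers: exp_avg and exp_avg_sq on the CPU, kahan_comp on the parameter's device. The NAdam variants would additionally apply a Nesterov-style correction to exp_avg before computing the update.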