This commit is contained in:
araleza
2026-02-13 09:11:46 +07:00
committed by GitHub
6 changed files with 607 additions and 27 deletions

View File

@@ -4846,9 +4846,6 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
optimizer_type = optimizer_type.lower()
if args.fused_backward_pass:
assert (
optimizer_type == "Adafactor".lower()
), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します"
assert (
args.gradient_accumulation_steps == 1
), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません"
@@ -5092,6 +5089,24 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
optimizer_class = transformers.optimization.Adafactor
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type.lower() == "adamoffload" or optimizer_type.lower() == "nadamoffload":
logger.info(f"use [N]AdamOffload optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.Adam
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type.lower() == "adamwoffload" or optimizer_type.lower() == "nadamwoffload":
logger.info(f"use [N]AdamWOffload optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.AdamW # default weight_decay seems to be 0.01
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type.lower() == "adanoffload":
logger.info(f"use AdanOffload optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.AdamW # todo: can't set beta3 here yet, need a custom Adan class
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "AdamW".lower():
logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.AdamW