Mirror of https://github.com/kohya-ss/sd-scripts.git
Support for fused (N)AdamW + Kahan + momentum offloading FFT on a 5090.
@@ -4813,9 +4813,6 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
     optimizer_type = optimizer_type.lower()
 
     if args.fused_backward_pass:
-        assert (
-            optimizer_type == "Adafactor".lower()
-        ), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します"
         assert (
             args.gradient_accumulation_steps == 1
         ), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません"
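The assertion removed above previously limited --fused_backward_pass to Adafactor; this commit relaxes it so the fused path can also be used with the new (N)AdamW offload variants. As an illustration only, a minimal sketch of the general PyTorch pattern rather than this repository's actual implementation: each trainable parameter gets a post-accumulate-grad hook that applies its own single-parameter optimizer step and frees the gradient immediately, which is also why the remaining assert keeps gradient_accumulation_steps at 1.

import torch

# Minimal sketch (not this repository's code): apply the optimizer step per
# parameter as soon as its gradient has been accumulated, then free the
# gradient so full-size .grad buffers for the whole model never coexist.
def enable_fused_backward(model: torch.nn.Module, **optimizer_kwargs):
    per_param_optimizers = {}

    def hook(p: torch.Tensor):
        opt = per_param_optimizers[p]
        opt.step()                       # update just this one parameter
        opt.zero_grad(set_to_none=True)  # release p.grad immediately

    for param in model.parameters():
        if param.requires_grad:
            # One single-parameter optimizer per tensor keeps the step local to the hook.
            per_param_optimizers[param] = torch.optim.AdamW([param], **optimizer_kwargs)
            # Fires right after param.grad is fully accumulated during backward() (PyTorch >= 2.1).
            param.register_post_accumulate_grad_hook(hook)

    return per_param_optimizers

With this pattern a plain loss.backward() performs the parameter updates as it goes, so no separate optimizer.step() is called afterwards, and no gradients can be carried across steps for accumulation.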
@@ -5059,6 +5056,18 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
         optimizer_class = transformers.optimization.Adafactor
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
+    elif optimizer_type.lower() == "adamoffload" or optimizer_type.lower() == "nadamoffload":
+        logger.info(f"use [N]AdamOffload optimizer | {optimizer_kwargs}")
+
+        optimizer_class = torch.optim.Adam
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+    elif optimizer_type.lower() == "adamwoffload" or optimizer_type.lower() == "nadamwoffload":
+        logger.info(f"use [N]AdamWOffload optimizer | {optimizer_kwargs}")
+
+        optimizer_class = torch.optim.AdamW  # default weight_decay seems to be 0.01
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
     elif optimizer_type == "AdamW".lower():
         logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
         optimizer_class = torch.optim.AdamW
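In this hunk the new adamoffload/nadamoffload and adamwoffload/nadamwoffload names are wired to the stock torch.optim.Adam / torch.optim.AdamW classes; the Kahan summation and momentum offloading mentioned in the commit title are not visible in this diff. As a hedged illustration of those two ideas (the function name, signature, and structure below are assumptions, not the code this commit adds), a single offloaded, Kahan-compensated AdamW-style step could look like the following, with the parameter, gradient, moments, and compensation buffer all kept in bf16:

import torch

# Hedged sketch of the two ideas named in the commit title: Kahan-compensated
# parameter updates and optimizer-state (momentum) offloading to CPU.
# Everything here is illustrative, not the code added by this commit.
@torch.no_grad()
def adamw_step_offloaded(
    param: torch.Tensor,       # bf16 weight on the GPU
    grad: torch.Tensor,        # bf16 gradient on the GPU
    exp_avg: torch.Tensor,     # bf16 first moment, kept in pinned CPU memory
    exp_avg_sq: torch.Tensor,  # bf16 second moment, kept in pinned CPU memory
    kahan_comp: torch.Tensor,  # bf16 compensation buffer, same shape as param, on the GPU
    step: int,
    lr: float = 1e-4,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
    weight_decay: float = 0.01,
):
    # Bring the offloaded moments onto the GPU for this parameter only.
    m = exp_avg.to(param.device, non_blocking=True)
    v = exp_avg_sq.to(param.device, non_blocking=True)

    # Decoupled weight decay, then the standard Adam moment updates.
    param.mul_(1 - lr * weight_decay)
    m.mul_(beta1).add_(grad, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias1 = 1 - beta1**step
    bias2 = 1 - beta2**step

    # Kahan-style compensated update: accumulate the intended update into the
    # compensation buffer, apply it, then keep the part lost to bf16 rounding
    # for the next step.
    kahan_comp.addcdiv_(m, (v / bias2).sqrt().add_(eps), value=-lr / bias1)
    old_param = param.clone()
    param.add_(kahan_comp)
    kahan_comp.add_(old_param.sub_(param))

    # Write the moments back to their pinned CPU buffers.
    exp_avg.copy_(m, non_blocking=True)
    exp_avg_sq.copy_(v, non_blocking=True)

Keeping the moments in pinned CPU memory and moving them per parameter with non_blocking=True trades PCIe bandwidth for VRAM, which is the trade-off that makes a full fine-tune plausible on a single consumer card such as the 5090 the title refers to, while the compensation buffer recovers most of the precision that bf16 state storage would otherwise lose.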