Mirror of https://github.com/kohya-ss/sd-scripts.git,
last synced 2026-04-17 09:18:00 +00:00.
Merge of commit da17be080e into ae72efb92b.
This commit is contained in the following branches/tags:
@@ -4846,9 +4846,6 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
|
||||
optimizer_type = optimizer_type.lower()
|
||||
|
||||
if args.fused_backward_pass:
|
||||
assert (
|
||||
optimizer_type == "Adafactor".lower()
|
||||
), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します"
|
||||
assert (
|
||||
args.gradient_accumulation_steps == 1
|
||||
), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません"
|
||||
@@ -5092,6 +5089,24 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
|
||||
optimizer_class = transformers.optimization.Adafactor
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type.lower() == "adamoffload" or optimizer_type.lower() == "nadamoffload":
|
||||
logger.info(f"use [N]AdamOffload optimizer | {optimizer_kwargs}")
|
||||
|
||||
optimizer_class = torch.optim.Adam
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type.lower() == "adamwoffload" or optimizer_type.lower() == "nadamwoffload":
|
||||
logger.info(f"use [N]AdamWOffload optimizer | {optimizer_kwargs}")
|
||||
|
||||
optimizer_class = torch.optim.AdamW # default weight_decay seems to be 0.01
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type.lower() == "adanoffload":
|
||||
logger.info(f"use AdanOffload optimizer | {optimizer_kwargs}")
|
||||
|
||||
optimizer_class = torch.optim.AdamW # todo: can't set beta3 here yet, need a custom Adan class
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type == "AdamW".lower():
|
||||
logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
|
||||
optimizer_class = torch.optim.AdamW
|
||||
|
||||
Reference in New Issue
Block a user