mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-17 17:24:21 +00:00
adding torchao and fixes
@@ -3599,7 +3599,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         "Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, "
         "DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, "
         "AdaFactor. "
-        "Also, you can use any optimizer by specifying the full path to the class, like 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit', or 'pytorch_optimizer.CAME'.",
+        "Also, you can use any optimizer by specifying the full path to the class, like 'bitsandbytes.optim.AdEMAMix8bit', 'bitsandbytes.optim.PagedAdEMAMix8bit', 'pytorch_optimizer.CAME', or 'torchao.prototype.low_bit_optim.adam.AdamW4bit'.",
     )

     # backward compatibility
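The help text above advertises loading an arbitrary optimizer by its fully qualified class path. As a rough illustration of how such a dotted path can be resolved at runtime (this is a sketch, not the repository's actual lookup code; load_optimizer_class is a hypothetical helper):

import importlib

def load_optimizer_class(full_path: str):
    # Split e.g. "torchao.prototype.low_bit_optim.adam.AdamW4bit" into a module
    # path and a class name, import the module, and return the class object.
    module_name, class_name = full_path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

# e.g. optimizer_class = load_optimizer_class("pytorch_optimizer.CAME")
#      optimizer = optimizer_class(trainable_params, lr=1e-4)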
@@ -4824,15 +4824,37 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
         optimizer_class = lion_pytorch.Lion
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

+    elif optimizer_type.endswith("4bit") or optimizer_type.endswith("Fp8"):
+        # https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim
+        try:
+            from torchao.prototype.low_bit_optim import AdamW4bit, AdamWFp8
+        except ImportError:
+            raise ImportError("No torchao / torchaoがインストールされていないようです")
+
+        if optimizer_type == "AdamW4bit".lower():
+            logger.info(f"use 4-bit AdamW optimizer | {optimizer_kwargs}")
+            optimizer_class = AdamW4bit
+            optimizer = optimizer_class(trainable_params, lr=torch.tensor(lr), **optimizer_kwargs)
+        elif optimizer_type == "AdamWFp8".lower():
+            logger.info(f"use AdamW Fp8 optimizer | {optimizer_kwargs}")
+            optimizer_class = AdamWFp8
+            optimizer = optimizer_class(trainable_params, lr=torch.tensor(lr), **optimizer_kwargs)
+
     elif optimizer_type.endswith("8bit".lower()):
         try:
             import bitsandbytes as bnb
         except ImportError:
             raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")

+        #try:
+        #    from torchao.prototype.low_bit_optim import AdamW8bit
+        #except ImportError:
+        #    raise ImportError("No torchao / torchaoがインストールされていないようです")
+
         if optimizer_type == "AdamW8bit".lower():
             logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
             optimizer_class = bnb.optim.AdamW8bit
+            #optimizer_class = AdamW8bit
             optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

         elif optimizer_type == "SGDNesterov8bit".lower():
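For context on the new branch above, the following is a standalone sketch of driving torchao's prototype 4-bit AdamW. It assumes torchao is installed and a CUDA device is available; the model and learning rate are placeholders, and the tensor-wrapped lr simply mirrors the hunk above (torchao's documentation suggests this so the compiled optimizer step is not recompiled when a scheduler later changes the learning rate).

import torch
from torchao.prototype.low_bit_optim import AdamW4bit

# Placeholder model; the quantized optimizer state lives on the same device
# as the parameters, which is where the memory savings show up.
model = torch.nn.Linear(128, 128).cuda()
lr = 1e-4

# As in the diff, lr is wrapped in a tensor (reportedly to avoid recompiling
# the torch.compile-d optimizer step when the learning rate is updated).
optimizer = AdamW4bit(model.parameters(), lr=torch.tensor(lr))

loss = model(torch.randn(4, 128, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()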
@@ -5014,6 +5036,8 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
             if args.lr_scheduler != "adafactor":
                 logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
             args.lr_scheduler = f"adafactor:{lr}"  # ちょっと微妙だけど
+            #logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
+            #args.lr_scheduler = f"adafactor:0"  # ちょっと微妙だけど

             lr = None
         else:
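The "adafactor:{lr}" scheduler string above carries the initial learning rate into an AdaFactor relative-step schedule built later. As a rough, self-contained illustration, assuming the Hugging Face transformers implementation (which this repository appears to use for AdaFactor); the model and initial_lr below are placeholders:

import torch
from transformers.optimization import Adafactor, AdafactorSchedule

model = torch.nn.Linear(16, 16)
initial_lr = 1e-4  # the value carried by the "adafactor:{lr}" string above

# With relative_step=True, AdaFactor derives its own step size, so no explicit
# lr is passed; AdafactorSchedule is a proxy that reports the effective lr.
optimizer = Adafactor(model.parameters(), relative_step=True, warmup_init=True, scale_parameter=True)
lr_scheduler = AdafactorSchedule(optimizer, initial_lr=initial_lr)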