Add torchao optimizer support and fixes

Darren Laurie
2025-03-17 01:37:23 +08:00
parent 4a3ced5fb9
commit d35c51a59e
3 changed files with 34 additions and 4 deletions


@@ -3599,7 +3599,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         "Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, "
         "DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, "
         "AdaFactor. "
-        "Also, you can use any optimizer by specifying the full path to the class, like 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit', or 'pytorch_optimizer.CAME'.",
+        "Also, you can use any optimizer by specifying the full path to the class, like 'bitsandbytes.optim.AdEMAMix8bit', 'bitsandbytes.optim.PagedAdEMAMix8bit', 'pytorch_optimizer.CAME', or 'torchao.prototype.low_bit_optim.adam.AdamW4bit'.",
     )
     # backward compatibility
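Note: the help text above lets users pass a fully qualified class path as the optimizer type. A minimal sketch of how such a path can be resolved at runtime (a hypothetical load_optimizer_class helper for illustration, not the code in this repository):

import importlib

def load_optimizer_class(full_path: str):
    # Split "torchao.prototype.low_bit_optim.adam.AdamW4bit" into the module
    # path and the class name, import the module, and return the class.
    module_path, class_name = full_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

For example, load_optimizer_class("pytorch_optimizer.CAME") returns the CAME class, which can then be instantiated with the trainable parameters and any optimizer kwargs.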
@@ -4824,15 +4824,37 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
         optimizer_class = lion_pytorch.Lion
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+    elif optimizer_type.endswith("4bit") or optimizer_type.endswith("Fp8"):
+        # https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim
+        try:
+            from torchao.prototype.low_bit_optim import AdamW4bit, AdamWFp8
+        except ImportError:
+            raise ImportError("No torchao / torchao does not seem to be installed")
+        if optimizer_type == "AdamW4bit".lower():
+            logger.info(f"use 4-bit AdamW optimizer | {optimizer_kwargs}")
+            optimizer_class = AdamW4bit
+            optimizer = optimizer_class(trainable_params, lr=torch.tensor(lr), **optimizer_kwargs)
+        elif optimizer_type == "AdamWFp8".lower():
+            logger.info(f"use AdamW Fp8 optimizer | {optimizer_kwargs}")
+            optimizer_class = AdamWFp8
+            optimizer = optimizer_class(trainable_params, lr=torch.tensor(lr), **optimizer_kwargs)
     elif optimizer_type.endswith("8bit".lower()):
         try:
             import bitsandbytes as bnb
         except ImportError:
             raise ImportError("No bitsandbytes / bitsandbytes does not seem to be installed")
+        # try:
+        #     from torchao.prototype.low_bit_optim import AdamW8bit
+        # except ImportError:
+        #     raise ImportError("No torchao / torchao does not seem to be installed")
         if optimizer_type == "AdamW8bit".lower():
             logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
             optimizer_class = bnb.optim.AdamW8bit
+            # optimizer_class = AdamW8bit
             optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
     elif optimizer_type == "SGDNesterov8bit".lower():
@@ -5014,6 +5036,8 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
             if args.lr_scheduler != "adafactor":
                 logger.info(f"use adafactor_scheduler / adafactor_scheduler will be used as the scheduler")
                 args.lr_scheduler = f"adafactor:{lr}"  # a bit hacky, though
+                # logger.info(f"use adafactor_scheduler / adafactor_scheduler will be used as the scheduler")
+                # args.lr_scheduler = f"adafactor:0"  # a bit hacky, though
             lr = None
         else:
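For context on the adafactor branch: when Adafactor runs in relative-step mode it computes its own step size, so the code stashes the requested rate in the scheduler spec ("adafactor:{lr}") and clears lr itself. A rough sketch of the idea, assuming the spec ends up backed by transformers' AdafactorSchedule (the actual wiring is outside this diff):

import torch
from transformers.optimization import Adafactor, AdafactorSchedule

model = torch.nn.Linear(128, 128)
# relative_step=True means Adafactor derives the step size internally;
# AdafactorSchedule just reports it so learning-rate logging keeps working.
optimizer = Adafactor(model.parameters(), lr=None, relative_step=True, warmup_init=True)
scheduler = AdafactorSchedule(optimizer, initial_lr=1e-3)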