Add Adafactor optimzier

This commit is contained in:
Kohya S
2023-02-22 21:09:47 +09:00
parent 663aad2b0d
commit 9ab964d0b8
5 changed files with 181 additions and 80 deletions

View File

@@ -1,6 +1,7 @@
# common functions for training
import argparse
import importlib
import json
import shutil
import time
@@ -21,6 +22,7 @@ import torch
from torch.optim import Optimizer
from torchvision import transforms
from transformers import CLIPTokenizer
import transformers
import diffusers
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers import DDPMScheduler, StableDiffusionPipeline
@@ -1371,28 +1373,29 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
def add_optimizer_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--optimizer_type", type=str, default="AdamW",
help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation")
help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
# backward compatibility
parser.add_argument("--use_8bit_adam", action="store_true",
help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使うbitsandbytesのインストールが必要")
parser.add_argument("--use_lion_optimizer", action="store_true",
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う lion-pytorch のインストールが必要)")
# parser.add_argument("--use_dadaptation_optimizer", action="store_true",
# help="use dadaptation optimizer (requires dadaptation) / dadaptaionオプティマイザを使う dadaptation のインストールが必要)")
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
parser.add_argument("--optimizer_momentum", type=float, default=0.9,
help="Momentum value for optimizers for SGD optimizers")
parser.add_argument("--optimizer_weight_decay", type=float, default=0.01,
help="Weight decay for optimizers")
parser.add_argument("--optimizer_beta1", type=float, default=0.9,
help="beta1 parameter for Adam optimizers")
parser.add_argument("--optimizer_beta2", type=float, default=0.999,
help="beta2 parameter for Adam optimizers")
parser.add_argument("--max_grad_norm", default=1.0, type=float,
help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない")
parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\"")
parser.add_argument("--lr_scheduler", type=str, default="constant",
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
parser.add_argument("--lr_warmup_steps", type=int, default=0,
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数デフォルト0")
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
parser.add_argument("--lr_scheduler_power", type=float, default=1,
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -1525,18 +1528,37 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
def get_optimizer(args, trainable_params):
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, dadaptation"
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
optimizer_type = args.optimizer_type
if args.use_8bit_adam:
print(f"*** use_8bit_adam option is specified. optimizer_type is ignored / use_8bit_adamオプションが指定されているためoptimizer_typeは無視されます")
optimizer_type = "AdamW8bit"
elif args.use_lion_optimizer:
print(f"*** use_lion_optimizer option is specified. optimizer_type is ignored / use_lion_optimizerオプションが指定されているためoptimizer_typeは無視されます")
optimizer_type = "Lion"
optimizer_type = optimizer_type.lower()
betas = (args.optimizer_beta1, args.optimizer_beta2)
weight_decay = args.optimizer_weight_decay
momentum = args.optimizer_momentum
# 引数を分解するboolとfloat、tupleのみ対応
optimizer_kwargs = {}
if args.optimizer_args is not None and len(args.optimizer_args) > 0:
for arg in args.optimizer_args:
key, value = arg.split('=')
value = value.split(",")
for i in range(len(value)):
if value[i].lower() == "true" or value[i].lower() == "false":
value[i] = (value[i].lower() == "true")
else:
value[i] = float(value[i])
if len(value) == 1:
value = value[0]
else:
value = tuple(value)
optimizer_kwargs[key] = value
print("optkwargs:", optimizer_kwargs)
lr = args.learning_rate
if optimizer_type == "AdamW8bit".lower():
@@ -1544,53 +1566,128 @@ def get_optimizer(args, trainable_params):
import bitsandbytes as bnb
except ImportError:
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
print(f"use 8-bit AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
optimizer_class = bnb.optim.AdamW8bit
optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "SGDNesterov8bit".lower():
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
print(f"use 8-bit SGD with Nesterov optimizer | Momentum: {momentum}, Weight Decay: {weight_decay}")
print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
if "momentum" not in optimizer_kwargs:
print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
optimizer_kwargs["momentum"] = 0.9
optimizer_class = bnb.optim.SGD8bit
optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
elif optimizer_type == "Lion".lower():
try:
import lion_pytorch
except ImportError:
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
print(f"use Lion optimizer | betas: {betas}, Weight Decay: {weight_decay}")
print(f"use Lion optimizer | {optimizer_kwargs}")
optimizer_class = lion_pytorch.Lion
optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "SGDNesterov".lower():
print(f"use SGD with Nesterov optimizer | Momentum: {momentum}, Weight Decay: {weight_decay}")
print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
if "momentum" not in optimizer_kwargs:
print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
optimizer_kwargs["momentum"] = 0.9
optimizer_class = torch.optim.SGD
optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
elif optimizer_type == "DAdaptation".lower():
try:
import dadaptation
except ImportError:
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
print(f"use D-Adaptation Adam optimizer | betas: {betas}, Weight Decay: {weight_decay}")
optimizer_class = dadaptation.DAdaptAdam
if lr <= 0.1:
print(f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {lr}')
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
min_lr = lr
if type(trainable_params) == list and type(trainable_params[0]) == dict:
for group in trainable_params:
min_lr = min(min_lr, group.get("lr", lr))
if min_lr <= 0.1:
print(
f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {min_lr}')
print('recommend option: lr=1.0 / 推奨は1.0です')
optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
optimizer_class = dadaptation.DAdaptAdam
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "Adafactor".lower():
# 引数を確認して適宜補正する
if "relative_step" not in optimizer_kwargs:
optimizer_kwargs["relative_step"] = True # default
if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
optimizer_kwargs["relative_step"] = True
print(f"use Adafactor optimizer | {optimizer_kwargs}")
if optimizer_kwargs["relative_step"]:
print(f"relative_step is true / relative_stepがtrueです")
if lr != 0.0:
print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
args.learning_rate = None
# trainable_paramsがgroupだった時の処理lrを削除する
if type(trainable_params) == list and type(trainable_params[0]) == dict:
has_group_lr = False
for group in trainable_params:
p = group.pop("lr", None)
has_group_lr = has_group_lr or (p is not None)
if has_group_lr:
# 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない
print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
args.unet_lr = None
args.text_encoder_lr = None
if args.lr_scheduler != "adafactor":
print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど
lr = None
else:
if args.max_grad_norm != 0.0:
print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません")
if args.lr_scheduler != "constant_with_warmup":
print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
optimizer_class = transformers.optimization.Adafactor
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "AdamW".lower():
print(f"use AdamW optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
else:
print(f"use AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
# 任意のoptimizerを使う
optimizer_type = args.optimizer_type # lowerでないやつ微妙
print(f"use {optimizer_type} | {optimizer_kwargs}")
if "." not in optimizer_type:
optimizer_module = torch.optim
else:
values = optimizer_type.split(".")
optimizer_module = importlib.import_module(".".join(values[:-1]))
optimizer_type = values[-1]
optimizer_class = getattr(optimizer_module, optimizer_type)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
return optimizer_name, optimizer
return optimizer_name, optimizer_args, optimizer
# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
@@ -1627,6 +1724,12 @@ def get_scheduler_fix(
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
"""
if name.startswith("adafactor"):
assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
initial_lr = float(name.split(':')[1])
# print("adafactor scheduler init lr", initial_lr)
return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
@@ -1744,13 +1847,19 @@ def prepare_dtype(args: argparse.Namespace):
def load_target_model(args: argparse.Namespace, weight_dtype):
load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers
name_or_path = args.pretrained_model_name_or_path
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
if load_stable_diffusion_format:
print("load StableDiffusion checkpoint")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
else:
print("load Diffusers pretrained models")
pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
try:
pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
except EnvironmentError as ex:
print(
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}")
text_encoder = pipe.text_encoder
vae = pipe.vae
unet = pipe.unet