mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Add Adafactor optimizer
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# common functions for training
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import json
|
||||
import shutil
|
||||
import time
|
||||
@@ -21,6 +22,7 @@ import torch
|
||||
from torch.optim import Optimizer
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPTokenizer
|
||||
import transformers
|
||||
import diffusers
|
||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||
from diffusers import DDPMScheduler, StableDiffusionPipeline
|
||||
@@ -1371,28 +1373,29 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
||||
|
||||
def add_optimizer_arguments(parser: argparse.ArgumentParser):
    """Register optimizer and learning-rate-scheduler command line options on *parser*.

    The selected optimizer is later instantiated by ``get_optimizer``; extra
    per-optimizer keyword arguments are passed via ``--optimizer_args``.
    """
    parser.add_argument("--optimizer_type", type=str, default="AdamW",
                        help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")

    # backward compatibility
    parser.add_argument("--use_8bit_adam", action="store_true",
                        help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
    parser.add_argument("--use_lion_optimizer", action="store_true",
                        help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
    # parser.add_argument("--use_dadaptation_optimizer", action="store_true",
    #                     help="use dadaptation optimizer (requires dadaptation) / dadaptaionオプティマイザを使う( dadaptation のインストールが必要)")

    parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
    # NOTE(review): the four options below are superseded by --optimizer_args
    # (get_optimizer builds kwargs from that); they are kept so existing
    # command lines keep parsing.
    parser.add_argument("--optimizer_momentum", type=float, default=0.9,
                        help="Momentum value for optimizers for SGD optimizers")
    parser.add_argument("--optimizer_weight_decay", type=float, default=0.01,
                        help="Weight decay for optimizers")
    parser.add_argument("--optimizer_beta1", type=float, default=0.9,
                        help="beta1 parameter for Adam optimizers")
    parser.add_argument("--optimizer_beta2", type=float, default=0.999,
                        help="beta2 parameter for Adam optimizers")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない")

    parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
                        help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")

    parser.add_argument("--lr_scheduler", type=str, default="constant",
                        help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
    parser.add_argument("--lr_warmup_steps", type=int, default=0,
                        help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
    parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
                        help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
    parser.add_argument("--lr_scheduler_power", type=float, default=1,
                        help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
||||
|
||||
|
||||
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
||||
@@ -1525,18 +1528,37 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
|
||||
|
||||
|
||||
def get_optimizer(args, trainable_params):
    """Build the optimizer selected by ``args.optimizer_type``.

    Supported names: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit,
    DAdaptation, Adafactor, or any ``ClassName`` from ``torch.optim`` /
    ``module.path.ClassName`` importable optimizer (case-insensitive match
    for the built-in names). Extra keyword arguments come from
    ``args.optimizer_args`` (``key=value`` strings; values are parsed as
    bool, float, or tuple of floats).

    Args:
        args: parsed command line namespace (see ``add_optimizer_arguments``).
        trainable_params: an iterable of parameters, or a list of
            param-group dicts as accepted by torch optimizers.

    Returns:
        ``(optimizer_name, optimizer_args_str, optimizer)`` where
        ``optimizer_name`` is ``"module.ClassName"`` and
        ``optimizer_args_str`` is a comma-joined ``key=value`` summary.

    Side effects: for Adafactor with ``relative_step=True`` this clears
    ``args.learning_rate`` (and per-group lrs / ``args.unet_lr`` /
    ``args.text_encoder_lr``) and rewrites ``args.lr_scheduler`` to
    ``"adafactor:<lr>"`` so ``get_scheduler_fix`` picks AdafactorSchedule.
    """
    optimizer_type = args.optimizer_type
    if args.use_8bit_adam:
        print(f"*** use_8bit_adam option is specified. optimizer_type is ignored / use_8bit_adamオプションが指定されているためoptimizer_typeは無視されます")
        optimizer_type = "AdamW8bit"
    elif args.use_lion_optimizer:
        print(f"*** use_lion_optimizer option is specified. optimizer_type is ignored / use_lion_optimizerオプションが指定されているためoptimizer_typeは無視されます")
        optimizer_type = "Lion"
    optimizer_type = optimizer_type.lower()

    # parse optimizer_args: only bool, float and tuple-of-float values are supported
    optimizer_kwargs = {}
    if args.optimizer_args is not None and len(args.optimizer_args) > 0:
        for arg in args.optimizer_args:
            key, value = arg.split('=')

            value = value.split(",")
            for i in range(len(value)):
                if value[i].lower() == "true" or value[i].lower() == "false":
                    value[i] = (value[i].lower() == "true")
                else:
                    value[i] = float(value[i])
            if len(value) == 1:
                value = value[0]
            else:
                value = tuple(value)

            optimizer_kwargs[key] = value

    lr = args.learning_rate

    if optimizer_type == "AdamW8bit".lower():
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
        print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
        optimizer_class = bnb.optim.AdamW8bit
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type == "SGDNesterov8bit".lower():
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
        print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
        # nesterov=True requires a non-zero momentum, so supply a default
        if "momentum" not in optimizer_kwargs:
            print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
            optimizer_kwargs["momentum"] = 0.9

        optimizer_class = bnb.optim.SGD8bit
        optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)

    elif optimizer_type == "Lion".lower():
        try:
            import lion_pytorch
        except ImportError:
            raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
        print(f"use Lion optimizer | {optimizer_kwargs}")
        optimizer_class = lion_pytorch.Lion
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type == "SGDNesterov".lower():
        print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
        # nesterov=True requires a non-zero momentum, so supply a default
        if "momentum" not in optimizer_kwargs:
            print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
            optimizer_kwargs["momentum"] = 0.9

        optimizer_class = torch.optim.SGD
        optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)

    elif optimizer_type == "DAdaptation".lower():
        try:
            import dadaptation
        except ImportError:
            raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
        print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")

        # D-Adaptation expects lr around 1.0; warn using the smallest lr of all groups
        min_lr = lr
        if type(trainable_params) == list and type(trainable_params[0]) == dict:
            for group in trainable_params:
                min_lr = min(min_lr, group.get("lr", lr))

        if min_lr <= 0.1:
            print(
                f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {min_lr}')
            print('recommend option: lr=1.0 / 推奨は1.0です')

        optimizer_class = dadaptation.DAdaptAdam
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type == "Adafactor".lower():
        # check arguments and fix them up where necessary
        if "relative_step" not in optimizer_kwargs:
            optimizer_kwargs["relative_step"] = True  # default
        if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
            # transformers' Adafactor requires relative_step when warmup_init is set
            print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
            optimizer_kwargs["relative_step"] = True
        print(f"use Adafactor optimizer | {optimizer_kwargs}")

        if optimizer_kwargs["relative_step"]:
            print(f"relative_step is true / relative_stepがtrueです")
            if lr != 0.0:
                print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
            args.learning_rate = None

            # when trainable_params is a list of param-group dicts, remove per-group lr
            if type(trainable_params) == list and type(trainable_params[0]) == dict:
                has_group_lr = False
                for group in trainable_params:
                    p = group.pop("lr", None)
                    has_group_lr = has_group_lr or (p is not None)

                if has_group_lr:
                    # also invalidate args  TODO the dependency direction is reversed here, not ideal
                    print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
                    args.unet_lr = None
                    args.text_encoder_lr = None

            if args.lr_scheduler != "adafactor":
                print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
                args.lr_scheduler = f"adafactor:{lr}"  # a bit hacky, but lets get_scheduler_fix recover the initial lr

            lr = None
        else:
            if args.max_grad_norm != 0.0:
                print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません")
            if args.lr_scheduler != "constant_with_warmup":
                print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
            if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
                print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")

        optimizer_class = transformers.optimization.Adafactor
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type == "AdamW".lower():
        print(f"use AdamW optimizer | {optimizer_kwargs}")
        optimizer_class = torch.optim.AdamW
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    else:
        # use an arbitrary optimizer: bare class name from torch.optim, or "module.path.ClassName"
        optimizer_type = args.optimizer_type  # the non-lowered spelling (a bit awkward)
        print(f"use {optimizer_type} | {optimizer_kwargs}")
        if "." not in optimizer_type:
            optimizer_module = torch.optim
        else:
            values = optimizer_type.split(".")
            optimizer_module = importlib.import_module(".".join(values[:-1]))
            optimizer_type = values[-1]

        optimizer_class = getattr(optimizer_module, optimizer_type)
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
    optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])

    return optimizer_name, optimizer_args, optimizer
|
||||
|
||||
|
||||
# Monkeypatch newer get_scheduler() function overriding current version of diffusers.optimizer.get_scheduler
|
||||
@@ -1627,6 +1724,12 @@ def get_scheduler_fix(
|
||||
last_epoch (`int`, *optional*, defaults to -1):
|
||||
The index of the last epoch when resuming training.
|
||||
"""
|
||||
if name.startswith("adafactor"):
|
||||
assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
|
||||
initial_lr = float(name.split(':')[1])
|
||||
# print("adafactor scheduler init lr", initial_lr)
|
||||
return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
|
||||
|
||||
name = SchedulerType(name)
|
||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||
if name == SchedulerType.CONSTANT:
|
||||
@@ -1744,13 +1847,19 @@ def prepare_dtype(args: argparse.Namespace):
|
||||
|
||||
|
||||
def load_target_model(args: argparse.Namespace, weight_dtype):
|
||||
load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers
|
||||
name_or_path = args.pretrained_model_name_or_path
|
||||
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||
if load_stable_diffusion_format:
|
||||
print("load StableDiffusion checkpoint")
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
|
||||
else:
|
||||
print("load Diffusers pretrained models")
|
||||
pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
|
||||
try:
|
||||
pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
|
||||
except EnvironmentError as ex:
|
||||
print(
|
||||
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}")
|
||||
text_encoder = pipe.text_encoder
|
||||
vae = pipe.vae
|
||||
unet = pipe.unet
|
||||
|
||||
Reference in New Issue
Block a user