Add Adafactor optimizer

commit 9ab964d0b8 (parent 663aad2b0d)
Author: Kohya S
Date:   2023-02-22 21:09:47 +09:00

5 changed files with 181 additions and 80 deletions

View File

@@ -149,7 +149,7 @@ def train(args):
   # prepare the classes needed for training
   print("prepare optimizer, data loader etc.")
-  _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+  _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)

   # prepare the dataloader
   # num_workers=0 makes the DataLoader run in the main process
@@ -163,8 +163,7 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

   # prepare the lr scheduler
-  lr_scheduler = train_util.get_scheduler_fix(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
       num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
       num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
@@ -268,11 +267,11 @@ def train(args):
         loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")

         accelerator.backward(loss)
-        if accelerator.sync_gradients:
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
           params_to_clip = []
           for m in training_models:
             params_to_clip.extend(m.parameters())
-          accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

         optimizer.step()
         lr_scheduler.step()
@@ -285,8 +284,8 @@ def train(args):
       current_loss = loss.detach().item()  # already a mean, so batch size should not matter
       if args.logging_dir is not None:
-        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
-        if args.optimizer_type == "DAdaptation".lower():  # tracking d*lr value
+        logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+        if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
           logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
         accelerator.log(logs, step=global_step)
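With the gate above, gradient clipping is now skipped entirely when --max_grad_norm is 0, which matters for optimizers such as Adafactor that clip their own updates. A minimal standalone sketch of the same gate, using placeholder tensors rather than the trainer's real parameters:

    import torch

    max_grad_norm = 0.0  # e.g. --max_grad_norm 0 to disable clipping
    params = [torch.nn.Parameter(torch.zeros(3))]
    params[0].grad = torch.randn(3) * 100.0  # a deliberately large gradient

    if max_grad_norm != 0.0:  # the same condition as the sync_gradients branch above
        torch.nn.utils.clip_grad_norm_(params, max_grad_norm)
    print(params[0].grad.norm())  # unchanged here, since clipping was skipped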

View File

@@ -1,6 +1,7 @@
 # common functions for training
 import argparse
+import importlib
 import json
 import shutil
 import time
@@ -21,6 +22,7 @@ import torch
 from torch.optim import Optimizer
 from torchvision import transforms
 from transformers import CLIPTokenizer
+import transformers
 import diffusers
 from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
 from diffusers import DDPMScheduler, StableDiffusionPipeline
@@ -1371,28 +1373,29 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
 def add_optimizer_arguments(parser: argparse.ArgumentParser):
   parser.add_argument("--optimizer_type", type=str, default="AdamW",
-                      help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation")
+                      help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
+
+  # backward compatibility
   parser.add_argument("--use_8bit_adam", action="store_true",
                       help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
   parser.add_argument("--use_lion_optimizer", action="store_true",
                       help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う(lion-pytorch のインストールが必要)")
+  # parser.add_argument("--use_dadaptation_optimizer", action="store_true",
+  #                     help="use dadaptation optimizer (requires dadaptation) / dadaptationオプティマイザを使う(dadaptation のインストールが必要)")

   parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
-  parser.add_argument("--optimizer_momentum", type=float, default=0.9,
-                      help="Momentum value for optimizers for SGD optimizers")
-  parser.add_argument("--optimizer_weight_decay", type=float, default=0.01,
-                      help="Weight decay for optimizers")
-  parser.add_argument("--optimizer_beta1", type=float, default=0.9,
-                      help="beta1 parameter for Adam optimizers")
-  parser.add_argument("--optimizer_beta2", type=float, default=0.999,
-                      help="beta2 parameter for Adam optimizers")
+  parser.add_argument("--max_grad_norm", default=1.0, type=float,
+                      help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない")
+  parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
+                      help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")

   parser.add_argument("--lr_scheduler", type=str, default="constant",
-                      help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
+                      help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
   parser.add_argument("--lr_warmup_steps", type=int, default=0,
                       help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
+  parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
+                      help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
+  parser.add_argument("--lr_scheduler_power", type=float, default=1,
+                      help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")

 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
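Because --optimizer_args is declared with nargs='*', each space-separated token on the command line arrives as one "key=value" string. A minimal sketch of the resulting namespace, assuming the common-functions module above is importable as train_util:

    import argparse
    import train_util  # assumption: the module shown above

    parser = argparse.ArgumentParser()
    train_util.add_optimizer_arguments(parser)
    args = parser.parse_args(["--optimizer_type", "Adafactor",
                              "--optimizer_args", "relative_step=True", "warmup_init=True",
                              "--max_grad_norm", "0"])
    print(args.optimizer_args)  # ['relative_step=True', 'warmup_init=True']
    print(args.max_grad_norm)   # 0.0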
@@ -1525,18 +1528,37 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
 def get_optimizer(args, trainable_params):
-  # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, dadaptation"
+  # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"

   optimizer_type = args.optimizer_type
   if args.use_8bit_adam:
+    print(f"*** use_8bit_adam option is specified. optimizer_type is ignored / use_8bit_adamオプションが指定されているためoptimizer_typeは無視されます")
     optimizer_type = "AdamW8bit"
   elif args.use_lion_optimizer:
+    print(f"*** use_lion_optimizer option is specified. optimizer_type is ignored / use_lion_optimizerオプションが指定されているためoptimizer_typeは無視されます")
     optimizer_type = "Lion"
   optimizer_type = optimizer_type.lower()

-  betas = (args.optimizer_beta1, args.optimizer_beta2)
-  weight_decay = args.optimizer_weight_decay
-  momentum = args.optimizer_momentum
+  # split the extra arguments; only bool, float, and tuple values are supported
+  optimizer_kwargs = {}
+  if args.optimizer_args is not None and len(args.optimizer_args) > 0:
+    for arg in args.optimizer_args:
+      key, value = arg.split('=')
+      value = value.split(",")
+      for i in range(len(value)):
+        if value[i].lower() == "true" or value[i].lower() == "false":
+          value[i] = (value[i].lower() == "true")
+        else:
+          value[i] = float(value[i])
+      if len(value) == 1:
+        value = value[0]
+      else:
+        value = tuple(value)
+      optimizer_kwargs[key] = value
+  print("optkwargs:", optimizer_kwargs)

   lr = args.learning_rate

   if optimizer_type == "AdamW8bit".lower():
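The parsing loop above understands only booleans, floats, and comma-separated tuples of floats; any other value raises ValueError. The same logic as a self-contained sketch with a worked example (the helper name is ours, not the library's):

    def parse_optimizer_args(arg_strings):
        # mirrors the loop above: bool, float, or comma-separated tuple of floats
        kwargs = {}
        for arg in arg_strings:
            key, value = arg.split('=')
            parts = value.split(',')
            parts = [p.lower() == 'true' if p.lower() in ('true', 'false') else float(p)
                     for p in parts]
            kwargs[key] = parts[0] if len(parts) == 1 else tuple(parts)
        return kwargs

    print(parse_optimizer_args(["weight_decay=0.01", "betas=0.9,0.999", "nesterov=True"]))
    # {'weight_decay': 0.01, 'betas': (0.9, 0.999), 'nesterov': True}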
@@ -1544,53 +1566,128 @@ def get_optimizer(args, trainable_params):
     try:
       import bitsandbytes as bnb
     except ImportError:
       raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
-    print(f"use 8-bit AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
+    print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
     optimizer_class = bnb.optim.AdamW8bit
-    optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
+    optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

   elif optimizer_type == "SGDNesterov8bit".lower():
     try:
       import bitsandbytes as bnb
     except ImportError:
       raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
-    print(f"use 8-bit SGD with Nesterov optimizer | Momentum: {momentum}, Weight Decay: {weight_decay}")
+    print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
+    if "momentum" not in optimizer_kwargs:
+      print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
+      optimizer_kwargs["momentum"] = 0.9
+
     optimizer_class = bnb.optim.SGD8bit
-    optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
+    optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)

   elif optimizer_type == "Lion".lower():
     try:
       import lion_pytorch
     except ImportError:
       raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
-    print(f"use Lion optimizer | betas: {betas}, Weight Decay: {weight_decay}")
+    print(f"use Lion optimizer | {optimizer_kwargs}")
     optimizer_class = lion_pytorch.Lion
-    optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
+    optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

   elif optimizer_type == "SGDNesterov".lower():
-    print(f"use SGD with Nesterov optimizer | Momentum: {momentum}, Weight Decay: {weight_decay}")
+    print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
+    if "momentum" not in optimizer_kwargs:
+      print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
+      optimizer_kwargs["momentum"] = 0.9
+
     optimizer_class = torch.optim.SGD
-    optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
+    optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)

   elif optimizer_type == "DAdaptation".lower():
     try:
       import dadaptation
     except ImportError:
       raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
-    print(f"use D-Adaptation Adam optimizer | betas: {betas}, Weight Decay: {weight_decay}")
-    optimizer_class = dadaptation.DAdaptAdam
-    if lr <= 0.1:
-      print(f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {lr}')
+    print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
+
+    min_lr = lr
+    if type(trainable_params) == list and type(trainable_params[0]) == dict:
+      for group in trainable_params:
+        min_lr = min(min_lr, group.get("lr", lr))
+
+    if min_lr <= 0.1:
+      print(
+          f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {min_lr}')
       print('recommend option: lr=1.0 / 推奨は1.0です')
-    optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
+
+    optimizer_class = dadaptation.DAdaptAdam
+    optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+  elif optimizer_type == "Adafactor".lower():
+    # check the arguments and correct them where needed
+    if "relative_step" not in optimizer_kwargs:
+      optimizer_kwargs["relative_step"] = True  # default
+    if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
+      print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
+      optimizer_kwargs["relative_step"] = True
+    print(f"use Adafactor optimizer | {optimizer_kwargs}")
+
+    if optimizer_kwargs["relative_step"]:
+      print(f"relative_step is true / relative_stepがtrueです")
+      if lr != 0.0:
+        print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
+      args.learning_rate = None
+
+      # when trainable_params is a list of param groups, remove the per-group lr
+      if type(trainable_params) == list and type(trainable_params[0]) == dict:
+        has_group_lr = False
+        for group in trainable_params:
+          p = group.pop("lr", None)
+          has_group_lr = has_group_lr or (p is not None)
+
+        if has_group_lr:
+          # also invalidate the args values  TODO the dependency is reversed here, which is not ideal
+          print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
+          args.unet_lr = None
+          args.text_encoder_lr = None
+
+      if args.lr_scheduler != "adafactor":
+        print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
+        args.lr_scheduler = f"adafactor:{lr}"  # a bit hacky, but it works
+
+      lr = None
+    else:
+      if args.max_grad_norm != 0.0:
+        print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません")
+      if args.lr_scheduler != "constant_with_warmup":
+        print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
+      if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
+        print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
+
+    optimizer_class = transformers.optimization.Adafactor
+    optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+  elif optimizer_type == "AdamW".lower():
+    print(f"use AdamW optimizer | {optimizer_kwargs}")
+    optimizer_class = torch.optim.AdamW
+    optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
   else:
-    print(f"use AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
-    optimizer_class = torch.optim.AdamW
-    optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
+    # use an arbitrary optimizer class
+    optimizer_type = args.optimizer_type  # the original casing, not the lowered one; a bit awkward
+    print(f"use {optimizer_type} | {optimizer_kwargs}")
+    if "." not in optimizer_type:
+      optimizer_module = torch.optim
+    else:
+      values = optimizer_type.split(".")
+      optimizer_module = importlib.import_module(".".join(values[:-1]))
+      optimizer_type = values[-1]
+
+    optimizer_class = getattr(optimizer_module, optimizer_type)
+    optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

   optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
+  optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])

-  return optimizer_name, optimizer
+  return optimizer_name, optimizer_args, optimizer
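The new fallback branch means --optimizer_type can name any optimizer class by its dotted import path, while bare names are looked up inside torch.optim. A hedged sketch of that resolution, using torch.optim.RMSprop as an illustrative target:

    import importlib
    import torch

    optimizer_type = "torch.optim.RMSprop"  # any "package.module.ClassName" string
    if "." not in optimizer_type:
        optimizer_module = torch.optim  # bare names resolve inside torch.optim
    else:
        *module_path, class_name = optimizer_type.split(".")
        optimizer_module = importlib.import_module(".".join(module_path))
        optimizer_type = class_name

    optimizer_class = getattr(optimizer_module, optimizer_type)
    optimizer = optimizer_class([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
    print(optimizer_class.__module__ + "." + optimizer_class.__name__)
    # e.g. torch.optim.rmsprop.RMSprop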
 # Monkeypatch newer get_scheduler() function overriding current version of diffusers.optimizer.get_scheduler

@@ -1627,6 +1724,12 @@ def get_scheduler_fix(
     last_epoch (`int`, *optional*, defaults to -1):
         The index of the last epoch when resuming training.
   """
+  if name.startswith("adafactor"):
+    assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
+    initial_lr = float(name.split(':')[1])
+    # print("adafactor scheduler init lr", initial_lr)
+    return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
+
   name = SchedulerType(name)
   schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

   if name == SchedulerType.CONSTANT:
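Putting the two halves together: with relative_step=True, Adafactor derives its step size internally, so the optimizer is built with lr=None and AdafactorSchedule exists mainly so that callers of get_last_lr() keep working. A minimal usage sketch, assuming transformers is installed (the tensors are placeholders):

    import torch
    import transformers

    params = [torch.nn.Parameter(torch.randn(4, 4))]
    optimizer = transformers.optimization.Adafactor(
        params, lr=None, relative_step=True, warmup_init=True)
    # corresponds to the "adafactor:<lr>" name built in get_optimizer above
    lr_scheduler = transformers.optimization.AdafactorSchedule(optimizer, initial_lr=1e-3)

    params[0].grad = torch.randn(4, 4)
    optimizer.step()
    lr_scheduler.step()
    print(lr_scheduler.get_last_lr())  # the internally derived step size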
@@ -1744,13 +1847,19 @@ def prepare_dtype(args: argparse.Namespace):
 def load_target_model(args: argparse.Namespace, weight_dtype):
-  load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)  # determine SD or Diffusers
+  name_or_path = args.pretrained_model_name_or_path
+  name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
+  load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
   if load_stable_diffusion_format:
     print("load StableDiffusion checkpoint")
-    text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
+    text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
   else:
     print("load Diffusers pretrained models")
-    pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
+    try:
+      pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
+    except EnvironmentError as ex:
+      print(
+          f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}")
     text_encoder = pipe.text_encoder
     vae = pipe.vae
     unet = pipe.unet
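A small sketch of the new checkpoint lookup, assuming a hypothetical local path that may be a symlink. One caveat worth knowing: os.readlink returns the stored link target, which can be relative to the link's own directory, whereas os.path.realpath would return an absolute resolved path:

    import os

    name_or_path = "models/current.ckpt"  # hypothetical; may be a file, a symlink, or an HF model id
    name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
    if os.path.isfile(name_or_path):
        print("load StableDiffusion checkpoint:", name_or_path)
    else:
        print("treat as Diffusers model id or directory:", name_or_path)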

View File

@@ -120,7 +120,7 @@ def train(args):
   else:
     trainable_params = unet.parameters()

-  _, optimizer = train_util.get_optimizer(args, trainable_params)
+  _, _, optimizer = train_util.get_optimizer(args, trainable_params)

   # prepare the dataloader
   # num_workers=0 makes the DataLoader run in the main process
@@ -137,8 +137,7 @@ def train(args):
     args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end

   # prepare the lr scheduler  TODO the handling of gradient_accumulation_steps may be wrong; check later
-  lr_scheduler = train_util.get_scheduler_fix(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
       num_training_steps=args.max_train_steps,
       num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
@@ -263,12 +262,12 @@ def train(args):
         loss = loss.mean()  # already the mean, so no need to divide by batch_size

         accelerator.backward(loss)
-        if accelerator.sync_gradients:
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
           if train_text_encoder:
             params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
           else:
             params_to_clip = unet.parameters()
-          accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

         optimizer.step()
         lr_scheduler.step()
@@ -281,8 +280,8 @@ def train(args):
       current_loss = loss.detach().item()
       if args.logging_dir is not None:
-        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
-        if args.optimizer_type == "DAdaptation".lower():  # tracking d*lr value
+        logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+        if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
           logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
         accelerator.log(logs, step=global_step)

View File

@@ -28,14 +28,14 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
   logs = {"loss/current": current_loss, "loss/average": avr_loss}

   if args.network_train_unet_only:
-    logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
+    logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
   elif args.network_train_text_encoder_only:
-    logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
+    logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
   else:
-    logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
-    logs["lr/unet"] = lr_scheduler.get_last_lr()[-1]  # may be the same as textencoder
+    logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
+    logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1])  # may be the same as textencoder

-  if args.optimizer_type == "DAdaptation".lower():  # tracking d*lr value of unet.
+  if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value of unet.
     logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']

   return logs
@@ -147,7 +147,7 @@ def train(args):
   print("prepare optimizer, data loader etc.")

   trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
-  optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params)
+  optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)

   # prepare the dataloader
   # num_workers=0 makes the DataLoader run in the main process
@@ -161,8 +161,7 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

   # prepare the lr scheduler
-  lr_scheduler = train_util.get_scheduler_fix(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
       num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
       num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
@@ -287,7 +286,7 @@ def train(args):
         "ss_bucket_info": json.dumps(train_dataset.bucket_info),
         "ss_training_comment": args.training_comment,  # will not be updated after training
         "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
-        "ss_optimizer": optimizer_name
+        "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else "")
     }

   # uncomment if another network is added
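With the three-value return from get_optimizer, the network metadata now records both the optimizer class and its keyword arguments in a single string. A tiny sketch of the value it produces (the example values are illustrative):

    optimizer_name = "transformers.optimization.Adafactor"
    optimizer_args = "relative_step=True,warmup_init=True"
    ss_optimizer = optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else "")
    print(ss_optimizer)
    # transformers.optimization.Adafactor(relative_step=True,warmup_init=True)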
@@ -380,9 +379,9 @@ def train(args):
         loss = loss.mean()  # already the mean, so no need to divide by batch_size

         accelerator.backward(loss)
-        if accelerator.sync_gradients:
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
           params_to_clip = network.get_trainable_params()
-          accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

         optimizer.step()
         lr_scheduler.step()
@@ -478,10 +477,6 @@ if __name__ == '__main__':
   parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
   parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
-  parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
-                      help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
-  parser.add_argument("--lr_scheduler_power", type=float, default=1,
-                      help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")

   parser.add_argument("--network_weights", type=str, default=None,
                       help="pretrained weights for network / 学習するネットワークの初期重み")

View File

@@ -199,7 +199,7 @@ def train(args):
   # prepare the classes needed for training
   print("prepare optimizer, data loader etc.")
   trainable_params = text_encoder.get_input_embeddings().parameters()
-  _, optimizer = train_util.get_optimizer(args, trainable_params)
+  _, _, optimizer = train_util.get_optimizer(args, trainable_params)

   # prepare the dataloader
   # num_workers=0 makes the DataLoader run in the main process
@@ -213,8 +213,7 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

   # prepare the lr scheduler
-  lr_scheduler = train_util.get_scheduler_fix(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
       num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
       num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
@@ -338,9 +337,9 @@ def train(args):
         loss = loss.mean()  # already the mean, so no need to divide by batch_size

         accelerator.backward(loss)
-        if accelerator.sync_gradients:
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
           params_to_clip = text_encoder.get_input_embeddings().parameters()
-          accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

         optimizer.step()
         lr_scheduler.step()
@@ -357,8 +356,8 @@ def train(args):
       current_loss = loss.detach().item()
       if args.logging_dir is not None:
-        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
-        if args.optimizer_type == "DAdaptation".lower():  # tracking d*lr value
+        logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+        if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
           logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
         accelerator.log(logs, step=global_step)