refactor get_scheduler etc.

Kohya S
2023-02-20 22:47:43 +09:00
parent 12d30afb39
commit 663aad2b0d
5 changed files with 119 additions and 100 deletions


@@ -5,6 +5,7 @@ import json
import shutil
import time
from typing import Dict, List, NamedTuple, Tuple
from typing import Optional, Union
from accelerate import Accelerator
from torch.autograd.function import Function
import glob
@@ -17,9 +18,11 @@ from io import BytesIO
from tqdm import tqdm
import torch
from torch.optim import Optimizer
from torchvision import transforms
from transformers import CLIPTokenizer
import diffusers
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers import DDPMScheduler, StableDiffusionPipeline
import albumentations as albu
import numpy as np
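Editor's note, for context on the two new imports: in diffusers, SchedulerType is a string-valued enum and TYPE_TO_SCHEDULER_FUNCTION maps each member to the factory that builds the corresponding LR schedule. A minimal sketch of the lookup that get_scheduler_fix below relies on:

```python
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION

# SchedulerType is a string-valued enum, so a plain string round-trips through it
name = SchedulerType("cosine_with_restarts")  # == SchedulerType.COSINE_WITH_RESTARTS

# TYPE_TO_SCHEDULER_FUNCTION maps each member to its factory function
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
print(schedule_func.__name__)  # get_cosine_with_hard_restarts_schedule_with_warmup
```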
@@ -1368,12 +1371,18 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
def add_optimizer_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--optimizer_type", type=str, default="AdamW",
help="Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit")
help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation")
parser.add_argument("--use_8bit_adam", action="store_true",
help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使うbitsandbytesのインストールが必要")
parser.add_argument("--use_lion_optimizer", action="store_true",
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う lion-pytorch のインストールが必要)")
# parser.add_argument("--use_dadaptation_optimizer", action="store_true",
# help="use dadaptation optimizer (requires dadaptation) / dadaptaionオプティマイザを使う dadaptation のインストールが必要)")
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
parser.add_argument("--optimizer_momentum", type=float, default=0.9,
help="Momentum value for optimizers")
parser.add_argument("--optimizer_weightdecay", type=float, default=0.01,
help="Momentum value for optimizers for SGD optimizers")
parser.add_argument("--optimizer_weight_decay", type=float, default=0.01,
help="Weight decay for optimizers")
parser.add_argument("--optimizer_beta1", type=float, default=0.9,
help="beta1 parameter for Adam optimizers")
@@ -1407,12 +1416,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長未指定で75、150または225が指定可")
# parser.add_argument("--use_8bit_adam", action="store_true",
# help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使うbitsandbytesのインストールが必要")
# parser.add_argument("--use_lion_optimizer", action="store_true",
# help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う lion-pytorch のインストールが必要)")
# parser.add_argument("--use_dadaptation_optimizer", action="store_true",
# help="use dadaptation optimizer (requires dadaptation) / dadaptaionオプティマイザを使う dadaptation のインストールが必要)")
parser.add_argument("--mem_eff_attn", action="store_true",
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
parser.add_argument("--xformers", action="store_true",
@@ -1520,14 +1523,19 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
# region utils
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, dadaption"
def get_optimizer(args, trainable_params):
# Prepare optimizer/学習に必要なクラスを準備する
optimizer_type = args.optimizer_type.lower()
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, dadaptation"
optimizer_type = args.optimizer_type
if args.use_8bit_adam:
optimizer_type = "AdamW8bit"
elif args.use_lion_optimizer:
optimizer_type = "Lion"
optimizer_type = optimizer_type.lower()
betas = (args.optimizer_beta1, args.optimizer_beta2)
weight_decay = args.optimizer_weightdecay
weight_decay = args.optimizer_weight_decay
momentum = args.optimizer_momentum
lr = args.learning_rate
@@ -1563,17 +1571,18 @@ def get_optimizer(args, trainable_params):
optimizer_class = torch.optim.SGD
optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
elif optimizer_type == "dadaptation".lower():
elif optimizer_type == "DAdaptation".lower():
try:
import dadaptation
except ImportError:
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
print(f"use dadaptation optimizer")
print(f"use D-Adaptation Adam optimizer | betas: {betas}, Weight Decay: {weight_decay}")
optimizer_class = dadaptation.DAdaptAdam
if args.learning_rate <= 0.1:
print('learning rate is too low. If using dadaptaion, set learning rate around 1.0.')
print('recommend option: lr=1.0')
optimizer = optimizer_class(trainable_params, lr=lr)
if lr <= 0.1:
print(f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {lr}')
print('recommend option: lr=1.0 / 推奨は1.0です')
optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
else:
print(f"use AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
optimizer_class = torch.optim.AdamW
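A hypothetical condensed sketch (not from the commit) of the dispatch pattern get_optimizer uses: the type string is lowercased once, so --optimizer_type is case-insensitive, and D-Adaptation gets the low-learning-rate warning because it treats lr as a multiplier on its adaptive step size, so values near 1.0 are expected.

```python
import torch

def pick_optimizer(optimizer_type, params, lr, betas=(0.9, 0.999), weight_decay=0.01):
    # lowercase once so "DAdaptation", "dadaptation", "DADAPTATION" all match
    t = optimizer_type.lower()
    if t == "lion":
        from lion_pytorch import Lion  # requires lion-pytorch
        return Lion(params, lr=lr, betas=betas, weight_decay=weight_decay)
    if t == "dadaptation":
        import dadaptation  # requires dadaptation
        if lr <= 0.1:
            print(f"learning rate {lr} is low for D-Adaptation; around 1.0 is recommended")
        return dadaptation.DAdaptAdam(params, lr=lr, betas=betas, weight_decay=weight_decay)
    # default branch, as in the diff above
    return torch.optim.AdamW(params, lr=lr, betas=betas, weight_decay=weight_decay)
```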
@@ -1584,6 +1593,69 @@ def get_optimizer(args, trainable_params):
return optimizer_name, optimizer
# Monkeypatch newer get_scheduler() function overriding current version of diffusers.optimization.get_scheduler
# code is taken from https://github.com/huggingface/diffusers diffusers.optimization, commit d87cc15977b87160c30abaace3894e802ad9e1e6
# Which is a newer release of diffusers than currently packaged with sd-scripts
# This code can be removed when a newer diffusers version (v0.12.1 or greater) is tested and implemented in sd-scripts
def get_scheduler_fix(
name: Union[str, SchedulerType],
optimizer: Optimizer,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
num_cycles: int = 1,
power: float = 1.0,
):
"""
Unified API to get any scheduler from its name.
Args:
name (`str` or `SchedulerType`):
The name of the scheduler to use.
optimizer (`torch.optim.Optimizer`):
The optimizer that will be used during training.
num_warmup_steps (`int`, *optional*):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_training_steps (`int`, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_cycles (`int`, *optional*):
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
power (`float`, *optional*, defaults to 1.0):
Power factor. See `POLYNOMIAL` scheduler
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer)
# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None:
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
if name == SchedulerType.COSINE_WITH_RESTARTS:
return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
)
if name == SchedulerType.POLYNOMIAL:
return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
)
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
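Illustrative usage of get_scheduler_fix as defined above, stepping a cosine schedule with warmup; the linear layer is a stand-in for the real trainable parameters.

```python
import torch

model = torch.nn.Linear(4, 4)  # stand-in for the real network
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# cosine decay after 100 warmup steps, over 1000 total steps
lr_scheduler = get_scheduler_fix(
    "cosine", optimizer, num_warmup_steps=100, num_training_steps=1000
)

for step in range(1000):
    optimizer.step()
    lr_scheduler.step()
```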
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
# backward compatibility
if args.caption_extention is not None: