mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
refactor get_scheduler etc.
This commit is contained in:
@@ -5,6 +5,7 @@ import json
|
||||
import shutil
|
||||
import time
|
||||
from typing import Dict, List, NamedTuple, Tuple
|
||||
from typing import Optional, Union
|
||||
from accelerate import Accelerator
|
||||
from torch.autograd.function import Function
|
||||
import glob
|
||||
@@ -17,9 +18,11 @@ from io import BytesIO
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPTokenizer
|
||||
import diffusers
|
||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||
from diffusers import DDPMScheduler, StableDiffusionPipeline
|
||||
import albumentations as albu
|
||||
import numpy as np
|
||||
@@ -1368,12 +1371,18 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
||||
|
||||
def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--optimizer_type", type=str, default="AdamW",
|
||||
help="Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit")
|
||||
help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation")
|
||||
parser.add_argument("--use_8bit_adam", action="store_true",
|
||||
help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
|
||||
parser.add_argument("--use_lion_optimizer", action="store_true",
|
||||
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
|
||||
# parser.add_argument("--use_dadaptation_optimizer", action="store_true",
|
||||
# help="use dadaptation optimizer (requires dadaptation) / dadaptaionオプティマイザを使う( dadaptation のインストールが必要)")
|
||||
|
||||
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
||||
parser.add_argument("--optimizer_momentum", type=float, default=0.9,
|
||||
help="Momentum value for optimizers")
|
||||
parser.add_argument("--optimizer_weightdecay", type=float, default=0.01,
|
||||
help="Momentum value for optimizers for SGD optimizers")
|
||||
parser.add_argument("--optimizer_weight_decay", type=float, default=0.01,
|
||||
help="Weight decay for optimizers")
|
||||
parser.add_argument("--optimizer_beta1", type=float, default=0.9,
|
||||
help="beta1 parameter for Adam optimizers")
|
||||
@@ -1407,12 +1416,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
||||
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
|
||||
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
|
||||
# parser.add_argument("--use_8bit_adam", action="store_true",
|
||||
# help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
|
||||
# parser.add_argument("--use_lion_optimizer", action="store_true",
|
||||
# help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
|
||||
# parser.add_argument("--use_dadaptation_optimizer", action="store_true",
|
||||
# help="use dadaptation optimizer (requires dadaptation) / dadaptaionオプティマイザを使う( dadaptation のインストールが必要)")
|
||||
parser.add_argument("--mem_eff_attn", action="store_true",
|
||||
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
|
||||
parser.add_argument("--xformers", action="store_true",
|
||||
@@ -1520,14 +1523,19 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
|
||||
|
||||
# region utils
|
||||
|
||||
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, dadaption"
|
||||
|
||||
def get_optimizer(args, trainable_params):
|
||||
# Prepare optimizer/学習に必要なクラスを準備する
|
||||
optimizer_type = args.optimizer_type.lower()
|
||||
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, dadaptation"
|
||||
|
||||
optimizer_type = args.optimizer_type
|
||||
if args.use_8bit_adam:
|
||||
optimizer_type = "AdamW8bit"
|
||||
elif args.use_lion_optimizer:
|
||||
optimizer_type = "Lion"
|
||||
optimizer_type = optimizer_type.lower()
|
||||
|
||||
betas = (args.optimizer_beta1, args.optimizer_beta2)
|
||||
weight_decay = args.optimizer_weightdecay
|
||||
weight_decay = args.optimizer_weight_decay
|
||||
momentum = args.optimizer_momentum
|
||||
lr = args.learning_rate
|
||||
|
||||
@@ -1563,17 +1571,18 @@ def get_optimizer(args, trainable_params):
|
||||
optimizer_class = torch.optim.SGD
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
|
||||
|
||||
elif optimizer_type == "dadaptation".lower():
|
||||
elif optimizer_type == "DAdaptation".lower():
|
||||
try:
|
||||
import dadaptation
|
||||
except ImportError:
|
||||
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
|
||||
print(f"use dadaptation optimizer")
|
||||
print(f"use D-Adaptation Adam optimizer | betas: {betas}, Weight Decay: {weight_decay}")
|
||||
optimizer_class = dadaptation.DAdaptAdam
|
||||
if args.learning_rate <= 0.1:
|
||||
print('learning rate is too low. If using dadaptaion, set learning rate around 1.0.')
|
||||
print('recommend option: lr=1.0')
|
||||
optimizer = optimizer_class(trainable_params, lr=lr)
|
||||
if lr <= 0.1:
|
||||
print(f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {lr}')
|
||||
print('recommend option: lr=1.0 / 推奨は1.0です')
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
|
||||
|
||||
else:
|
||||
print(f"use AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
|
||||
optimizer_class = torch.optim.AdamW
|
||||
@@ -1584,6 +1593,69 @@ def get_optimizer(args, trainable_params):
|
||||
return optimizer_name, optimizer
|
||||
|
||||
|
||||
# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
|
||||
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
||||
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
||||
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
||||
|
||||
|
||||
def get_scheduler_fix(
|
||||
name: Union[str, SchedulerType],
|
||||
optimizer: Optimizer,
|
||||
num_warmup_steps: Optional[int] = None,
|
||||
num_training_steps: Optional[int] = None,
|
||||
num_cycles: int = 1,
|
||||
power: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Unified API to get any scheduler from its name.
|
||||
Args:
|
||||
name (`str` or `SchedulerType`):
|
||||
The name of the scheduler to use.
|
||||
optimizer (`torch.optim.Optimizer`):
|
||||
The optimizer that will be used during training.
|
||||
num_warmup_steps (`int`, *optional*):
|
||||
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
||||
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||
num_training_steps (`int``, *optional*):
|
||||
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
||||
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||
num_cycles (`int`, *optional*):
|
||||
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
||||
power (`float`, *optional*, defaults to 1.0):
|
||||
Power factor. See `POLYNOMIAL` scheduler
|
||||
last_epoch (`int`, *optional*, defaults to -1):
|
||||
The index of the last epoch when resuming training.
|
||||
"""
|
||||
name = SchedulerType(name)
|
||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||
if name == SchedulerType.CONSTANT:
|
||||
return schedule_func(optimizer)
|
||||
|
||||
# All other schedulers require `num_warmup_steps`
|
||||
if num_warmup_steps is None:
|
||||
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
||||
|
||||
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
||||
|
||||
# All other schedulers require `num_training_steps`
|
||||
if num_training_steps is None:
|
||||
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
||||
|
||||
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
||||
return schedule_func(
|
||||
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
||||
)
|
||||
|
||||
if name == SchedulerType.POLYNOMIAL:
|
||||
return schedule_func(
|
||||
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
||||
)
|
||||
|
||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
||||
|
||||
|
||||
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
||||
# backward compatibility
|
||||
if args.caption_extention is not None:
|
||||
|
||||
Reference in New Issue
Block a user