mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
monkeypatch updated get_scheduler for diffusers
enables use of "num_cycles" and "power" for cosine_with_restarts and polynomial learning rate schedulers
This commit is contained in:
@@ -35,6 +35,75 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
|
|||||||
return logs
|
return logs
|
||||||
|
|
||||||
|
|
||||||
|
# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
|
||||||
|
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
||||||
|
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
||||||
|
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
||||||
|
|
||||||
|
from typing import Optional, Union
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||||
|
|
||||||
|
def get_scheduler_fix(
|
||||||
|
name: Union[str, SchedulerType],
|
||||||
|
optimizer: Optimizer,
|
||||||
|
num_warmup_steps: Optional[int] = None,
|
||||||
|
num_training_steps: Optional[int] = None,
|
||||||
|
num_cycles: int = 1,
|
||||||
|
power: float = 1.0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Unified API to get any scheduler from its name.
|
||||||
|
Args:
|
||||||
|
name (`str` or `SchedulerType`):
|
||||||
|
The name of the scheduler to use.
|
||||||
|
optimizer (`torch.optim.Optimizer`):
|
||||||
|
The optimizer that will be used during training.
|
||||||
|
num_warmup_steps (`int`, *optional*):
|
||||||
|
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
||||||
|
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||||
|
num_training_steps (`int``, *optional*):
|
||||||
|
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
||||||
|
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||||
|
num_cycles (`int`, *optional*):
|
||||||
|
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
||||||
|
power (`float`, *optional*, defaults to 1.0):
|
||||||
|
Power factor. See `POLYNOMIAL` scheduler
|
||||||
|
last_epoch (`int`, *optional*, defaults to -1):
|
||||||
|
The index of the last epoch when resuming training.
|
||||||
|
"""
|
||||||
|
name = SchedulerType(name)
|
||||||
|
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||||
|
if name == SchedulerType.CONSTANT:
|
||||||
|
return schedule_func(optimizer)
|
||||||
|
|
||||||
|
# All other schedulers require `num_warmup_steps`
|
||||||
|
if num_warmup_steps is None:
|
||||||
|
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
||||||
|
|
||||||
|
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
||||||
|
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
||||||
|
|
||||||
|
# All other schedulers require `num_training_steps`
|
||||||
|
if num_training_steps is None:
|
||||||
|
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
||||||
|
|
||||||
|
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
||||||
|
return schedule_func(
|
||||||
|
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
||||||
|
)
|
||||||
|
|
||||||
|
if name == SchedulerType.POLYNOMIAL:
|
||||||
|
return schedule_func(
|
||||||
|
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
||||||
|
)
|
||||||
|
|
||||||
|
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
||||||
|
|
||||||
|
diffusers.optimization.get_scheduler = get_scheduler_fix
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
session_id = random.randint(0, 2**32)
|
session_id = random.randint(0, 2**32)
|
||||||
training_started_at = time.time()
|
training_started_at = time.time()
|
||||||
@@ -157,7 +226,9 @@ def train(args):
|
|||||||
|
|
||||||
# lr schedulerを用意する
|
# lr schedulerを用意する
|
||||||
lr_scheduler = diffusers.optimization.get_scheduler(
|
lr_scheduler = diffusers.optimization.get_scheduler(
|
||||||
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
||||||
|
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||||
|
num_cycles = args.num_cycles, power = args.power)
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
@@ -460,6 +531,10 @@ if __name__ == '__main__':
|
|||||||
help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
|
help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
|
||||||
parser.add_argument("--training_comment", type=str, default=None,
|
parser.add_argument("--training_comment", type=str, default=None,
|
||||||
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
|
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
|
||||||
|
parser.add_argument("--num_cycles", type=int, default=1,
|
||||||
|
help="Number of restarts for cosine scheduler with restarts")
|
||||||
|
parser.add_argument("--power", type=float, default=1,
|
||||||
|
help="Polynomial power for polynomial scheduler")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
train(args)
|
train(args)
|
||||||
|
|||||||
Reference in New Issue
Block a user