refactor get_scheduler etc.

This commit is contained in:
Kohya S
2023-02-20 22:47:43 +09:00
parent 12d30afb39
commit 663aad2b0d
5 changed files with 119 additions and 100 deletions

View File

@@ -1,8 +1,5 @@
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from torch.optim import Optimizer
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from typing import Optional, Union
import importlib
import argparse
import gc
@@ -26,6 +23,7 @@ def collate_fn(examples):
return examples[0]
# TODO 他のスクリプトと共通化する
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
@@ -37,75 +35,12 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
if args.use_dadaptation_optimizer: # tracking d*lr value of unet.
if args.optimizer_type == "DAdaptation".lower(): # tracking d*lr value of unet.
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
return logs
# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
# Which is a newer release of diffusers than currently packaged with sd-scripts
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
def get_scheduler_fix(
name: Union[str, SchedulerType],
optimizer: Optimizer,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
num_cycles: int = 1,
power: float = 1.0,
):
"""
Unified API to get any scheduler from its name.
Args:
name (`str` or `SchedulerType`):
The name of the scheduler to use.
optimizer (`torch.optim.Optimizer`):
The optimizer that will be used during training.
num_warmup_steps (`int`, *optional*):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_training_steps (`int``, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_cycles (`int`, *optional*):
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
power (`float`, *optional*, defaults to 1.0):
Power factor. See `POLYNOMIAL` scheduler
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer)
# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None:
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
if name == SchedulerType.COSINE_WITH_RESTARTS:
return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
)
if name == SchedulerType.POLYNOMIAL:
return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
)
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
def train(args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
@@ -164,7 +99,7 @@ def train(args):
if args.lowram:
text_encoder.to("cuda")
unet.to("cuda")
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -226,8 +161,7 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
# lr_scheduler = diffusers.optimization.get_scheduler(
lr_scheduler = get_scheduler_fix(
lr_scheduler = train_util.get_scheduler_fix(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)