mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Add comment
This commit is contained in:
@@ -1,3 +1,6 @@
|
||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||
from torch.optim import Optimizer
|
||||
from typing import Optional, Union
|
||||
import importlib
|
||||
import argparse
|
||||
import gc
|
||||
@@ -40,9 +43,6 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
|
||||
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
||||
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
||||
|
||||
from typing import Optional, Union
|
||||
from torch.optim import Optimizer
|
||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||
|
||||
def get_scheduler_fix(
|
||||
name: Union[str, SchedulerType],
|
||||
@@ -52,53 +52,53 @@ def get_scheduler_fix(
|
||||
num_cycles: int = 1,
|
||||
power: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Unified API to get any scheduler from its name.
|
||||
Args:
|
||||
name (`str` or `SchedulerType`):
|
||||
The name of the scheduler to use.
|
||||
optimizer (`torch.optim.Optimizer`):
|
||||
The optimizer that will be used during training.
|
||||
num_warmup_steps (`int`, *optional*):
|
||||
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
||||
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||
num_training_steps (`int``, *optional*):
|
||||
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
||||
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||
num_cycles (`int`, *optional*):
|
||||
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
||||
power (`float`, *optional*, defaults to 1.0):
|
||||
Power factor. See `POLYNOMIAL` scheduler
|
||||
last_epoch (`int`, *optional*, defaults to -1):
|
||||
The index of the last epoch when resuming training.
|
||||
"""
|
||||
name = SchedulerType(name)
|
||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||
if name == SchedulerType.CONSTANT:
|
||||
return schedule_func(optimizer)
|
||||
"""
|
||||
Unified API to get any scheduler from its name.
|
||||
Args:
|
||||
name (`str` or `SchedulerType`):
|
||||
The name of the scheduler to use.
|
||||
optimizer (`torch.optim.Optimizer`):
|
||||
The optimizer that will be used during training.
|
||||
num_warmup_steps (`int`, *optional*):
|
||||
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
||||
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||
num_training_steps (`int``, *optional*):
|
||||
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
||||
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||
num_cycles (`int`, *optional*):
|
||||
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
||||
power (`float`, *optional*, defaults to 1.0):
|
||||
Power factor. See `POLYNOMIAL` scheduler
|
||||
last_epoch (`int`, *optional*, defaults to -1):
|
||||
The index of the last epoch when resuming training.
|
||||
"""
|
||||
name = SchedulerType(name)
|
||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||
if name == SchedulerType.CONSTANT:
|
||||
return schedule_func(optimizer)
|
||||
|
||||
# All other schedulers require `num_warmup_steps`
|
||||
if num_warmup_steps is None:
|
||||
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
||||
# All other schedulers require `num_warmup_steps`
|
||||
if num_warmup_steps is None:
|
||||
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
||||
|
||||
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
||||
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
||||
|
||||
# All other schedulers require `num_training_steps`
|
||||
if num_training_steps is None:
|
||||
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
||||
# All other schedulers require `num_training_steps`
|
||||
if num_training_steps is None:
|
||||
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
||||
|
||||
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
||||
return schedule_func(
|
||||
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
||||
)
|
||||
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
||||
return schedule_func(
|
||||
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
||||
)
|
||||
|
||||
if name == SchedulerType.POLYNOMIAL:
|
||||
return schedule_func(
|
||||
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
||||
)
|
||||
if name == SchedulerType.POLYNOMIAL:
|
||||
return schedule_func(
|
||||
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
||||
)
|
||||
|
||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
||||
|
||||
|
||||
def train(args):
|
||||
@@ -135,7 +135,7 @@ def train(args):
|
||||
train_util.debug_dataset(train_dataset)
|
||||
return
|
||||
if len(train_dataset) == 0:
|
||||
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
||||
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
||||
return
|
||||
|
||||
# acceleratorを準備する
|
||||
@@ -224,7 +224,7 @@ def train(args):
|
||||
# lr schedulerを用意する
|
||||
# lr_scheduler = diffusers.optimization.get_scheduler(
|
||||
lr_scheduler = get_scheduler_fix(
|
||||
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
||||
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user