monkeypatch updated get_scheduler for diffusers

enables use of "num_cycles" and "power" for cosine_with_restarts and polynomial learning rate schedulers
This commit is contained in:
michaelgzhang
2023-01-27 16:42:11 -06:00
parent 67e698af67
commit 0fef7b4684

View File

@@ -35,6 +35,75 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
return logs
# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
# Which is a newer release of diffusers than currently packaged with sd-scripts
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
from typing import Optional, Union
from torch.optim import Optimizer
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
def get_scheduler_fix(
name: Union[str, SchedulerType],
optimizer: Optimizer,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
num_cycles: int = 1,
power: float = 1.0,
):
"""
Unified API to get any scheduler from its name.
Args:
name (`str` or `SchedulerType`):
The name of the scheduler to use.
optimizer (`torch.optim.Optimizer`):
The optimizer that will be used during training.
num_warmup_steps (`int`, *optional*):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_training_steps (`int``, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_cycles (`int`, *optional*):
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
power (`float`, *optional*, defaults to 1.0):
Power factor. See `POLYNOMIAL` scheduler
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer)
# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None:
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
if name == SchedulerType.COSINE_WITH_RESTARTS:
return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
)
if name == SchedulerType.POLYNOMIAL:
return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
)
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
diffusers.optimization.get_scheduler = get_scheduler_fix
def train(args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
@@ -157,7 +226,9 @@ def train(args):
# lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles = args.num_cycles, power = args.power)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
@@ -460,6 +531,10 @@ if __name__ == '__main__':
help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
parser.add_argument("--training_comment", type=str, default=None,
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
parser.add_argument("--num_cycles", type=int, default=1,
help="Number of restarts for cosine scheduler with restarts")
parser.add_argument("--power", type=float, default=1,
help="Polynomial power for polynomial scheduler")
args = parser.parse_args()
train(args)