Add New lr scheduler (#1393)

* add new lr scheduler

* fix bugs and use num_cycles / 2

* Update requirements.txt

* add num_cycles for min lr

* keep PIECEWISE_CONSTANT

* allow use float with warmup or decay ratio.

* Update train_util.py
This commit is contained in:
青龍聖者@bdsqlsz
2024-09-11 20:25:45 +08:00
committed by GitHub
parent 62ec3e6424
commit fd68703f37
2 changed files with 75 additions and 11 deletions

View File

@@ -42,7 +42,8 @@ from torch.optim import Optimizer
from torchvision import transforms
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
import transformers
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers.optimization import SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION
from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers import (
StableDiffusionPipeline,
DDPMScheduler,
@@ -2972,6 +2973,20 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
def add_optimizer_arguments(parser: argparse.ArgumentParser):
def int_or_float(value):
if value.endswith('%'):
try:
return float(value[:-1]) / 100.0
except ValueError:
raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
try:
float_value = float(value)
if float_value >= 1:
return int(value)
return float(value)
except ValueError:
raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
parser.add_argument(
"--optimizer_type",
type=str,
@@ -3024,9 +3039,15 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
)
parser.add_argument(
"--lr_warmup_steps",
type=int,
type=int_or_float,
default=0,
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数デフォルト0",
help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps / 学習率のスケジューラをウォームアップするステップ数デフォルト0",
)
parser.add_argument(
"--lr_decay_steps",
type=int_or_float,
default=0,
help="Int number of steps for the decay in the lr scheduler (default is 0) or float with ratio of train steps",
)
parser.add_argument(
"--lr_scheduler_num_cycles",
@@ -3046,6 +3067,18 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL"
+ " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効",
)
parser.add_argument(
"--lr_scheduler_timescale",
type=int,
default=None,
help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`",
)
parser.add_argument(
"--lr_scheduler_min_lr_ratio",
type=float,
default=None,
help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler",
)
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -4293,10 +4326,14 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
Unified API to get any scheduler from its name.
"""
name = args.lr_scheduler
num_warmup_steps: Optional[int] = args.lr_warmup_steps
num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
num_warmup_steps: Optional[int] = int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
num_decay_steps: Optional[int] = int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
num_cycles = args.lr_scheduler_num_cycles
power = args.lr_scheduler_power
timescale = args.lr_scheduler_timescale
min_lr_ratio = args.lr_scheduler_min_lr_ratio
lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
@@ -4332,13 +4369,13 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
# logger.info(f"adafactor scheduler init lr {initial_lr}")
return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
name = SchedulerType(name) or DiffusersSchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] or DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
if name == SchedulerType.PIECEWISE_CONSTANT:
if name == DiffusersSchedulerType.PIECEWISE_CONSTANT:
return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs
# All other schedulers require `num_warmup_steps`
@@ -4348,6 +4385,9 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)
if name == SchedulerType.INVERSE_SQRT:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
@@ -4366,7 +4406,31 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs
)
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs)
if name == SchedulerType.COSINE_WITH_MIN_LR:
return schedule_func(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_cycles=num_cycles / 2,
min_lr_rate=min_lr_ratio,
**lr_scheduler_kwargs,
)
# All other schedulers require `num_decay_steps`
if num_decay_steps is None:
raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
if name == SchedulerType.WARMUP_STABLE_DECAY:
return schedule_func(
optimizer,
num_warmup_steps=num_warmup_steps,
num_stable_steps=num_stable_steps,
num_decay_steps=num_decay_steps,
num_cycles=num_cycles / 2,
min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
**lr_scheduler_kwargs,
)
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_decay_steps=num_decay_steps, **lr_scheduler_kwargs)
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):

View File

@@ -1,5 +1,5 @@
accelerate==0.25.0
transformers==4.36.2
accelerate==0.30.0
transformers==4.41.2
diffusers[torch]==0.25.0
ftfy==6.1.1
# albumentations==1.3.0
@@ -16,7 +16,7 @@ altair==4.2.2
easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.20.1
huggingface-hub==0.23.3
# for Image utils
imagesize==1.4.1
# for BLIP captioning