Add New lr scheduler (#1393)

* add new lr scheduler

* fix bugs and use num_cycles / 2

* Update requirements.txt

* add num_cycles for min lr

* keep PIECEWISE_CONSTANT

* allow using a float for the warmup or decay ratio

* Update train_util.py
青龍聖者@bdsqlsz
2024-09-11 20:25:45 +08:00
committed by GitHub
parent 62ec3e6424
commit fd68703f37
2 changed files with 75 additions and 11 deletions

train_util.py

@@ -42,7 +42,8 @@ from torch.optim import Optimizer
 from torchvision import transforms
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
 import transformers
-from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
+from diffusers.optimization import SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION
+from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
 from diffusers import (
     StableDiffusionPipeline,
     DDPMScheduler,
@@ -2972,6 +2973,20 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
 def add_optimizer_arguments(parser: argparse.ArgumentParser):
+    def int_or_float(value):
+        if value.endswith('%'):
+            try:
+                return float(value[:-1]) / 100.0
+            except ValueError:
+                raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
+        try:
+            float_value = float(value)
+            if float_value >= 1:
+                return int(value)
+            return float(value)
+        except ValueError:
+            raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
+
     parser.add_argument(
         "--optimizer_type",
         type=str,
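The helper accepts three spellings for a step argument. A minimal standalone sketch of its behaviour (the calls below are my own illustration, not part of the commit):

import argparse

def int_or_float(value):  # copied from the hunk above
    if value.endswith('%'):
        try:
            return float(value[:-1]) / 100.0
        except ValueError:
            raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
    try:
        float_value = float(value)
        if float_value >= 1:
            return int(value)
        return float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")

print(int_or_float("500"))   # 500  -> absolute step count (int)
print(int_or_float("0.05"))  # 0.05 -> ratio of total training steps
print(int_or_float("5%"))    # 0.05 -> percentage, normalized to a ratio
# Note: "1.5" is rejected; int("1.5") raises ValueError, which falls through
# to the ArgumentTypeError branch, so values >= 1 must be whole numbers.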
@@ -3024,9 +3039,15 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
     )
     parser.add_argument(
         "--lr_warmup_steps",
-        type=int,
+        type=int_or_float,
         default=0,
-        help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
+        help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
     )
+    parser.add_argument(
+        "--lr_decay_steps",
+        type=int_or_float,
+        default=0,
+        help="Int number of steps for the decay in the lr scheduler (default is 0) or float with ratio of train steps",
+    )
     parser.add_argument(
         "--lr_scheduler_num_cycles",
@@ -3046,6 +3067,18 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL" help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL"
+ " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効", + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効",
) )
parser.add_argument(
"--lr_scheduler_timescale",
type=int,
default=None,
help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`",
)
parser.add_argument(
"--lr_scheduler_min_lr_ratio",
type=float,
default=None,
help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler",
)
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
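A quick note on what the min-lr knob means in absolute terms (my own arithmetic, hypothetical values): with a base learning rate of 1e-4, a ratio of 0.1 floors the schedule at 1e-5.

# Hypothetical values illustrating --lr_scheduler_min_lr_ratio
base_lr = 1e-4
min_lr_ratio = 0.1
print(base_lr * min_lr_ratio)  # 1e-05: the lr the schedule decays to instead of 0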
@@ -4293,10 +4326,14 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     Unified API to get any scheduler from its name.
     """
     name = args.lr_scheduler
-    num_warmup_steps: Optional[int] = args.lr_warmup_steps
     num_training_steps = args.max_train_steps * num_processes  # * args.gradient_accumulation_steps
+    num_warmup_steps: Optional[int] = int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
+    num_decay_steps: Optional[int] = int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
+    num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
     num_cycles = args.lr_scheduler_num_cycles
     power = args.lr_scheduler_power
+    timescale = args.lr_scheduler_timescale
+    min_lr_ratio = args.lr_scheduler_min_lr_ratio
 
     lr_scheduler_kwargs = {}  # get custom lr_scheduler kwargs
     if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
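To make the new step accounting concrete, here is the resolution with hypothetical numbers (1000 total steps, 10% warmup, 20% decay):

num_training_steps = 1000                   # max_train_steps * num_processes
lr_warmup_steps, lr_decay_steps = 0.1, 0.2  # floats, i.e. ratios of total steps
num_warmup_steps = int(lr_warmup_steps * num_training_steps) if isinstance(lr_warmup_steps, float) else lr_warmup_steps
num_decay_steps = int(lr_decay_steps * num_training_steps) if isinstance(lr_decay_steps, float) else lr_decay_steps
num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
print(num_warmup_steps, num_stable_steps, num_decay_steps)  # 100 700 200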
@@ -4332,13 +4369,13 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
         # logger.info(f"adafactor scheduler init lr {initial_lr}")
         return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
 
-    name = SchedulerType(name)
-    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+    name = SchedulerType(name) or DiffusersSchedulerType(name)
+    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] or DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
 
     if name == SchedulerType.CONSTANT:
         return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
 
-    if name == SchedulerType.PIECEWISE_CONSTANT:
+    if name == DiffusersSchedulerType.PIECEWISE_CONSTANT:
         return schedule_func(optimizer, **lr_scheduler_kwargs)  # step_rules and last_epoch are given as kwargs
 
     # All other schedulers require `num_warmup_steps`
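One caveat worth flagging: `SchedulerType(name)` raises `ValueError` for a value that only exists in the diffusers enum instead of returning something falsy, so the `or` fallback above never actually runs for diffusers-only names such as `piecewise_constant`. A try/except sketch that would fall back as presumably intended:

try:
    name = SchedulerType(name)                      # transformers registry first
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
except ValueError:
    name = DiffusersSchedulerType(name)             # then the diffusers registry
    schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]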
@@ -4348,6 +4385,9 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     if name == SchedulerType.CONSTANT_WITH_WARMUP:
         return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)
 
+    if name == SchedulerType.INVERSE_SQRT:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
+
     # All other schedulers require `num_training_steps`
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
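For reference, the decay shape that `--lr_scheduler_timescale` controls, as I understand transformers' `get_inverse_sqrt_schedule` (a sketch with hypothetical numbers, not code from the PR):

import math

num_warmup_steps = 100
timescale = 100  # per the help text, defaults to num_warmup_steps

def lr_multiplier(step):
    if step < num_warmup_steps:  # linear warmup from 0 to 1
        return step / max(1, num_warmup_steps)
    shift = timescale - num_warmup_steps
    return 1.0 / math.sqrt((step + shift) / timescale)  # inverse sqrt decay

print(lr_multiplier(400))  # 0.5 -- the lr has halved by 4x the timescale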
@@ -4366,7 +4406,31 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
             optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs
         )
 
-    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs)
+    if name == SchedulerType.COSINE_WITH_MIN_LR:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=num_cycles / 2,
+            min_lr_rate=min_lr_ratio,
+            **lr_scheduler_kwargs,
+        )
+
+    # All other schedulers require `num_decay_steps`
+    if num_decay_steps is None:
+        raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
+    if name == SchedulerType.WARMUP_STABLE_DECAY:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_stable_steps=num_stable_steps,
+            num_decay_steps=num_decay_steps,
+            num_cycles=num_cycles / 2,
+            min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
+            **lr_scheduler_kwargs,
+        )
+
+    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_decay_steps=num_decay_steps, **lr_scheduler_kwargs)
 
 
 def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
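To visualize what the new warmup_stable_decay option produces, here is a rough sketch of the lr multiplier over a run (my own illustration of the shape; the actual schedule comes from the transformers registry, where WARMUP_STABLE_DECAY maps to `get_wsd_schedule`):

import math

# Hypothetical split: 100 warmup, 700 stable, 200 decay steps, floor at 10%
num_warmup_steps, num_stable_steps, num_decay_steps = 100, 700, 200
min_lr_ratio = 0.1

def lr_multiplier(step):
    if step < num_warmup_steps:                     # 1) linear warmup: 0 -> 1
        return step / max(1, num_warmup_steps)
    if step < num_warmup_steps + num_stable_steps:  # 2) stable plateau: hold at 1
        return 1.0
    progress = (step - num_warmup_steps - num_stable_steps) / max(1, num_decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))  # 3) cosine decay: 1 -> 0
    return min_lr_ratio + (1.0 - min_lr_ratio) * cosine            # rescale so the floor is min_lr_ratio

for s in (0, 50, 100, 500, 900, 1000):
    print(s, round(lr_multiplier(s), 3))  # 0.0, 0.5, 1.0, 1.0, 0.55, 0.1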

requirements.txt

@@ -1,5 +1,5 @@
-accelerate==0.25.0
-transformers==4.36.2
+accelerate==0.30.0
+transformers==4.41.2
 diffusers[torch]==0.25.0
 ftfy==6.1.1
 # albumentations==1.3.0
@@ -16,7 +16,7 @@ altair==4.2.2
 easygui==0.98.3
 toml==0.10.2
 voluptuous==0.13.1
-huggingface-hub==0.20.1
+huggingface-hub==0.23.3
 # for Image utils
 imagesize==1.4.1
 # for BLIP captioning