Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
Add New lr scheduler (#1393)
* add new lr scheduler
* fix bugs and use num_cycles / 2
* Update requirements.txt
* add num_cycles for min lr
* keep PIECEWISE_CONSTANT
* allow use float with warmup or decay ratio.
* Update train_util.py
train_util.py

@@ -42,7 +42,8 @@ from torch.optim import Optimizer
 from torchvision import transforms
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
 import transformers
-from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
+from diffusers.optimization import SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION
+from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
 from diffusers import (
     StableDiffusionPipeline,
     DDPMScheduler,
@@ -2972,6 +2973,20 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):


 def add_optimizer_arguments(parser: argparse.ArgumentParser):
+    def int_or_float(value):
+        if value.endswith('%'):
+            try:
+                return float(value[:-1]) / 100.0
+            except ValueError:
+                raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
+        try:
+            float_value = float(value)
+            if float_value >= 1:
+                return int(value)
+            return float(value)
+        except ValueError:
+            raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
+
     parser.add_argument(
         "--optimizer_type",
         type=str,
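The new int_or_float parser accepts three spellings for the same value: an absolute step count, a ratio of the total training steps, or a percentage. A minimal standalone check (the helper is copied from the diff above; the test values are illustrative):

import argparse

def int_or_float(value):
    if value.endswith('%'):
        try:
            return float(value[:-1]) / 100.0
        except ValueError:
            raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
    try:
        float_value = float(value)
        # Note: for values >= 1 the raw string is fed to int(), so "1.5"
        # raises ValueError here and is re-raised as an ArgumentTypeError.
        if float_value >= 1:
            return int(value)
        return float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")

assert int_or_float("500") == 500    # absolute step count
assert int_or_float("0.1") == 0.1    # ratio of the total training steps
assert int_or_float("10%") == 0.1    # percentage, normalized to a ratio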
@@ -3024,9 +3039,15 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
     )
     parser.add_argument(
         "--lr_warmup_steps",
-        type=int,
+        type=int_or_float,
         default=0,
-        help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
+        help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
     )
+    parser.add_argument(
+        "--lr_decay_steps",
+        type=int_or_float,
+        default=0,
+        help="Int number of steps for the decay in the lr scheduler (default is 0) or float with ratio of train steps",
+    )
     parser.add_argument(
         "--lr_scheduler_num_cycles",
@@ -3046,6 +3067,18 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL"
         + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効",
     )
+    parser.add_argument(
+        "--lr_scheduler_timescale",
+        type=int,
+        default=None,
+        help="Inverse sqrt timescale for inverse sqrt scheduler, defaults to `num_warmup_steps`",
+    )
+    parser.add_argument(
+        "--lr_scheduler_min_lr_ratio",
+        type=float,
+        default=None,
+        help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler",
+    )


 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
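As a quick sanity check of the new flags, a minimal sketch that registers just these two arguments the same way as above and inspects the parsed namespace (the values are illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lr_scheduler_timescale", type=int, default=None)
parser.add_argument("--lr_scheduler_min_lr_ratio", type=float, default=None)

args = parser.parse_args(["--lr_scheduler_timescale", "1000",
                          "--lr_scheduler_min_lr_ratio", "0.1"])
assert args.lr_scheduler_timescale == 1000      # consumed by the inverse sqrt branch
assert args.lr_scheduler_min_lr_ratio == 0.1    # consumed by the min-lr / WSD branches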
@@ -4293,10 +4326,14 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     Unified API to get any scheduler from its name.
     """
     name = args.lr_scheduler
-    num_warmup_steps: Optional[int] = args.lr_warmup_steps
     num_training_steps = args.max_train_steps * num_processes  # * args.gradient_accumulation_steps
+    num_warmup_steps: Optional[int] = int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
+    num_decay_steps: Optional[int] = int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
+    num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
     num_cycles = args.lr_scheduler_num_cycles
     power = args.lr_scheduler_power
+    timescale = args.lr_scheduler_timescale
+    min_lr_ratio = args.lr_scheduler_min_lr_ratio

     lr_scheduler_kwargs = {}  # get custom lr_scheduler kwargs
     if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
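A worked example of the ratio-to-steps conversion above (all numbers illustrative): a float is scaled by the total training steps, an int passes through unchanged, and the stable phase is whatever remains.

num_processes = 2
max_train_steps = 1000
num_training_steps = max_train_steps * num_processes             # 2000

lr_warmup_steps = 0.1   # float -> treated as a ratio of train steps
lr_decay_steps = 500    # int   -> used as an absolute step count

num_warmup_steps = int(lr_warmup_steps * num_training_steps) if isinstance(lr_warmup_steps, float) else lr_warmup_steps  # 200
num_decay_steps = int(lr_decay_steps * num_training_steps) if isinstance(lr_decay_steps, float) else lr_decay_steps      # 500
num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps                                               # 1300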
@@ -4332,13 +4369,13 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
         # logger.info(f"adafactor scheduler init lr {initial_lr}")
         return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))

-    name = SchedulerType(name)
-    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+    name = SchedulerType(name) or DiffusersSchedulerType(name)
+    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] or DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]

     if name == SchedulerType.CONSTANT:
         return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))

-    if name == SchedulerType.PIECEWISE_CONSTANT:
+    if name == DiffusersSchedulerType.PIECEWISE_CONSTANT:
         return schedule_func(optimizer, **lr_scheduler_kwargs)  # step_rules and last_epoch are given as kwargs

     # All other schedulers require `num_warmup_steps`
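One thing to note about the lookup above: constructing an Enum from an unknown value raises ValueError rather than returning something falsy, so the `or` fallback only runs when the first lookup already succeeded. A try/except is one way to get a real transformers-then-diffusers fallback; a sketch using the same imports as the diff (the helper name resolve_scheduler is hypothetical):

from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers.optimization import (
    SchedulerType as DiffusersSchedulerType,
    TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
)

def resolve_scheduler(name: str):
    # Try the transformers registry first; fall back to diffusers
    # (e.g. "piecewise_constant" only exists on the diffusers side).
    try:
        scheduler_type = SchedulerType(name)
        return scheduler_type, TYPE_TO_SCHEDULER_FUNCTION[scheduler_type]
    except ValueError:
        scheduler_type = DiffusersSchedulerType(name)
        return scheduler_type, DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[scheduler_type]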
@@ -4348,6 +4385,9 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     if name == SchedulerType.CONSTANT_WITH_WARMUP:
         return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)

+    if name == SchedulerType.INVERSE_SQRT:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
+
     # All other schedulers require `num_training_steps`
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
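For intuition, the inverse sqrt schedule warms up linearly and then decays roughly as 1/sqrt(step), with timescale controlling how quickly the decay bites. A rough sketch of the shape (illustrative only; transformers' get_inverse_sqrt_schedule is the authoritative implementation):

import math

def inverse_sqrt_lr(step, base_lr, num_warmup_steps, timescale):
    # Linear warmup, then ~1/sqrt(step) decay. Per the argument parsing
    # above, timescale defaults to num_warmup_steps when not given.
    if step < num_warmup_steps:
        return base_lr * step / max(1, num_warmup_steps)
    shift = timescale - num_warmup_steps
    return base_lr / math.sqrt((step + shift) / timescale)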
@@ -4366,7 +4406,31 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
             optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs
         )

-    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs)
+    if name == SchedulerType.COSINE_WITH_MIN_LR:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=num_cycles / 2,
+            min_lr_rate=min_lr_ratio,
+            **lr_scheduler_kwargs,
+        )
+
+    # All other schedulers require `num_decay_steps`
+    if num_decay_steps is None:
+        raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
+    if name == SchedulerType.WARMUP_STABLE_DECAY:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_stable_steps=num_stable_steps,
+            num_decay_steps=num_decay_steps,
+            num_cycles=num_cycles / 2,
+            min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
+            **lr_scheduler_kwargs,
+        )
+
+    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_decay_steps=num_decay_steps, **lr_scheduler_kwargs)


 def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
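The warmup-stable-decay branch above splits training into three phases, with num_stable_steps derived earlier as whatever remains after warmup and decay. A rough sketch of the resulting curve (illustrative; transformers' get_wsd_schedule is the authoritative implementation, including the num_cycles handling):

import math

def wsd_lr(step, base_lr, num_warmup_steps, num_stable_steps, num_decay_steps, min_lr_ratio=0.0):
    if step < num_warmup_steps:                      # phase 1: linear warmup
        return base_lr * step / max(1, num_warmup_steps)
    if step < num_warmup_steps + num_stable_steps:   # phase 2: hold at base_lr
        return base_lr
    # phase 3: cosine-style decay down to min_lr_ratio * base_lr
    progress = (step - num_warmup_steps - num_stable_steps) / max(1, num_decay_steps)
    progress = min(progress, 1.0)
    factor = 0.5 * (1.0 + math.cos(math.pi * progress))
    return base_lr * (min_lr_ratio + (1.0 - min_lr_ratio) * factor)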
requirements.txt

@@ -1,5 +1,5 @@
-accelerate==0.25.0
-transformers==4.36.2
+accelerate==0.30.0
+transformers==4.41.2
 diffusers[torch]==0.25.0
 ftfy==6.1.1
 # albumentations==1.3.0
@@ -16,7 +16,7 @@ altair==4.2.2
 easygui==0.98.3
 toml==0.10.2
 voluptuous==0.13.1
-huggingface-hub==0.20.1
+huggingface-hub==0.23.3
 # for Image utils
 imagesize==1.4.1
 # for BLIP captioning