Merge branch 'dev' into sd3

Kohya S
2024-09-11 21:46:21 +09:00
2 changed files with 101 additions and 8 deletions

library/train_util.py

@@ -44,7 +44,11 @@ from torch.optim import Optimizer
 from torchvision import transforms
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
 import transformers
-from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
+from diffusers.optimization import (
+    SchedulerType as DiffusersSchedulerType,
+    TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
+)
+from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
 from diffusers import (
     StableDiffusionPipeline,
     DDPMScheduler,
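
Both libraries export a SchedulerType enum with overlapping member values, which is why the diffusers names get aliased instead of imported directly. A minimal sketch of the collision (member values as of recent diffusers/transformers releases):

from diffusers.optimization import SchedulerType as DiffusersSchedulerType
from transformers.optimization import SchedulerType

# Both enums define a member whose value is "constant"; a second unaliased
# import would silently shadow the first.
print(SchedulerType.CONSTANT.value)           # "constant" (transformers)
print(DiffusersSchedulerType.CONSTANT.value)  # "constant" (diffusers)
# "piecewise_constant" exists only on the diffusers side.
print(DiffusersSchedulerType.PIECEWISE_CONSTANT.value)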
@@ -3250,6 +3254,20 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
 def add_optimizer_arguments(parser: argparse.ArgumentParser):
+    def int_or_float(value):
+        if value.endswith("%"):
+            try:
+                return float(value[:-1]) / 100.0
+            except ValueError:
+                raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
+        try:
+            float_value = float(value)
+            if float_value >= 1:
+                return int(value)
+            return float(value)
+        except ValueError:
+            raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
+
     parser.add_argument(
         "--optimizer_type",
         type=str,
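
The converter accepts an absolute step count, a ratio below 1, or a percentage. A quick demonstration with a stand-alone parser (hypothetical, for illustration; int_or_float as defined above):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lr_warmup_steps", type=int_or_float, default=0)

print(parser.parse_args(["--lr_warmup_steps", "500"]).lr_warmup_steps)   # 500 (>= 1: absolute steps, int)
print(parser.parse_args(["--lr_warmup_steps", "0.05"]).lr_warmup_steps)  # 0.05 (< 1: ratio of train steps)
print(parser.parse_args(["--lr_warmup_steps", "5%"]).lr_warmup_steps)    # 0.05 (percentage -> ratio)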
@@ -3302,9 +3320,17 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
     )
     parser.add_argument(
         "--lr_warmup_steps",
-        type=int,
+        type=int_or_float,
         default=0,
-        help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
+        help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps"
+        " / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
     )
+    parser.add_argument(
+        "--lr_decay_steps",
+        type=int_or_float,
+        default=0,
+        help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps"
+        " / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
+    )
     parser.add_argument(
         "--lr_scheduler_num_cycles",
@@ -3324,6 +3350,21 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL"
         + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効",
     )
+    parser.add_argument(
+        "--lr_scheduler_timescale",
+        type=int,
+        default=None,
+        help="Inverse sqrt timescale for inverse sqrt scheduler, defaults to `num_warmup_steps`"
+        " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`",
+    )
+    parser.add_argument(
+        "--lr_scheduler_min_lr_ratio",
+        type=float,
+        default=None,
+        help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler"
+        " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラで有効",
+    )

 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
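
For context, --lr_scheduler_timescale feeds transformers' get_inverse_sqrt_schedule. A sketch of the multiplier that scheduler applies, based on its documented behavior (not a verbatim copy of the library code):

import math

def inverse_sqrt_multiplier(current_step: int, num_warmup_steps: int, timescale: int) -> float:
    # Linear warmup, then LR proportional to 1/sqrt(step). With the default
    # timescale == num_warmup_steps, the multiplier is exactly 1.0 at the end
    # of warmup and decays smoothly afterwards.
    if current_step < num_warmup_steps:
        return current_step / max(1, num_warmup_steps)
    shift = timescale - num_warmup_steps
    return 1.0 / math.sqrt((current_step + shift) / timescale)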
@@ -4571,10 +4612,18 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     Unified API to get any scheduler from its name.
     """
     name = args.lr_scheduler
-    num_warmup_steps: Optional[int] = args.lr_warmup_steps
     num_training_steps = args.max_train_steps * num_processes  # * args.gradient_accumulation_steps
+    num_warmup_steps: Optional[int] = (
+        int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
+    )
+    num_decay_steps: Optional[int] = (
+        int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
+    )
+    num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
     num_cycles = args.lr_scheduler_num_cycles
     power = args.lr_scheduler_power
+    timescale = args.lr_scheduler_timescale
+    min_lr_ratio = args.lr_scheduler_min_lr_ratio

     lr_scheduler_kwargs = {}  # get custom lr_scheduler kwargs
     if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
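
A worked example of the resolution above, assuming max_train_steps=1000 and num_processes=2:

num_training_steps = 1000 * 2        # 2000 total optimizer steps
num_warmup_steps = int(0.05 * 2000)  # --lr_warmup_steps 0.05 (float) -> 100 steps
num_decay_steps = 400                # --lr_decay_steps 400 (int) -> used as-is
num_stable_steps = 2000 - 100 - 400  # 1500 steps at full LR (used by warmup_stable_decay)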
@@ -4610,15 +4659,17 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
         # logger.info(f"adafactor scheduler init lr {initial_lr}")
         return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))

+    if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value:
+        name = DiffusersSchedulerType(name)
+        schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
+        return schedule_func(optimizer, **lr_scheduler_kwargs)  # step_rules and last_epoch are given as kwargs
+
     name = SchedulerType(name)
     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

     if name == SchedulerType.CONSTANT:
         return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))

-    if name == SchedulerType.PIECEWISE_CONSTANT:
-        return schedule_func(optimizer, **lr_scheduler_kwargs)  # step_rules and last_epoch are given as kwargs
-
     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
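
The diffusers branch has to run first: "piecewise_constant" is not a member of transformers' SchedulerType, so routing that name through the second enum would fail. A minimal check (assuming recent library versions):

from transformers.optimization import SchedulerType

SchedulerType("piecewise_constant")  # raises ValueError: not a valid SchedulerType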
@@ -4626,6 +4677,9 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     if name == SchedulerType.CONSTANT_WITH_WARMUP:
         return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)

+    if name == SchedulerType.INVERSE_SQRT:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
+
     # All other schedulers require `num_training_steps`
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
@@ -4644,7 +4698,37 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
             optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs
         )

-    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs)
+    if name == SchedulerType.COSINE_WITH_MIN_LR:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=num_cycles / 2,
+            min_lr_rate=min_lr_ratio,
+            **lr_scheduler_kwargs,
+        )
+
+    # All other schedulers require `num_decay_steps`
+    if num_decay_steps is None:
+        raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
+    if name == SchedulerType.WARMUP_STABLE_DECAY:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_stable_steps=num_stable_steps,
+            num_decay_steps=num_decay_steps,
+            num_cycles=num_cycles / 2,
+            min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
+            **lr_scheduler_kwargs,
+        )
+
+    return schedule_func(
+        optimizer,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_decay_steps=num_decay_steps,
+        **lr_scheduler_kwargs,
+    )

 def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
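
Taken together, warmup_stable_decay yields a three-phase schedule. The sketch below approximates its shape with the worked numbers from above (an illustration of the phases, not transformers' exact implementation):

import math

def wsd_multiplier(step: int, warmup: int = 100, stable: int = 1500, decay: int = 400, min_lr_ratio: float = 0.1) -> float:
    # Phase 1: linear warmup from 0 to the full learning rate.
    if step < warmup:
        return step / max(1, warmup)
    # Phase 2: hold the learning rate flat.
    if step < warmup + stable:
        return 1.0
    # Phase 3: cosine decay down to min_lr_ratio, then hold the floor.
    if step < warmup + stable + decay:
        progress = (step - warmup - stable) / max(1, decay)
        return min_lr_ratio + (1.0 - min_lr_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr_ratio

An illustrative invocation for a script that calls get_scheduler_fix (e.g. train_network.py): --lr_scheduler warmup_stable_decay --lr_warmup_steps 0.05 --lr_decay_steps 0.2 --lr_scheduler_min_lr_ratio 0.1.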