mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Simplify Timestep weighting
* Remove diffusers dependency in ts & sigma calc * support Shift setting * Add uniform distribution * Default to Uniform distribution and shift 1
This commit is contained in:
@@ -253,12 +253,12 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
|
|||||||
" / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります",
|
" / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります",
|
||||||
)
|
)
|
||||||
|
|
||||||
# copy from Diffusers
|
# The dependency on the Diffusers noise sampler has been removed for clarity.
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--weighting_scheme",
|
"--weighting_scheme",
|
||||||
type=str,
|
type=str,
|
||||||
default="logit_normal",
|
default="uniform",
|
||||||
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
|
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"],
|
||||||
help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム",
|
help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -279,8 +279,13 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
|
|||||||
default=1.29,
|
default=1.29,
|
||||||
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. / モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効",
|
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. / モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--training_shift",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。",
|
||||||
|
)
|
||||||
|
|
||||||
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
||||||
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
||||||
if args.v_parameterization:
|
if args.v_parameterization:
|
||||||
@@ -965,14 +970,20 @@ def get_noisy_model_input_and_timesteps(
|
|||||||
logit_std=args.logit_std,
|
logit_std=args.logit_std,
|
||||||
mode_scale=args.mode_scale,
|
mode_scale=args.mode_scale,
|
||||||
)
|
)
|
||||||
indices = (u * noise_scheduler.config.num_train_timesteps).long()
|
t_min = args.min_timestep if args.min_timestep is not None else 0
|
||||||
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
t_max = args.max_timestep if args.max_timestep is not None else 1000
|
||||||
|
shift = args.training_shift
|
||||||
|
|
||||||
# Add noise according to flow matching.
|
# Timestep shift: a value >1 shifts the distribution toward the noisy side (more focus on overall structure); a value <1 shifts it toward the less-noisy side (more focus on details).
|
||||||
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
u = (u * shift) / (1 + (shift - 1) * u)
|
||||||
|
|
||||||
|
indices = (u * (t_max-t_min) + t_min).long()
|
||||||
|
timesteps = indices.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# sigmas according to flow matching
|
||||||
|
sigmas = timesteps / 1000
|
||||||
|
sigmas = sigmas.view(-1,1,1,1)
|
||||||
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
||||||
|
|
||||||
return noisy_model_input, timesteps, sigmas
|
return noisy_model_input, timesteps, sigmas
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
|
||||||
|
|||||||
Reference in New Issue
Block a user