Merge branch 'sd3' into faster-block-swap

This commit is contained in:
Kohya S
2024-11-07 21:43:48 +09:00
8 changed files with 64 additions and 28 deletions

View File

@@ -526,7 +526,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
secondary_separator: {subset.secondary_separator}
enable_wildcard: {subset.enable_wildcard}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}

View File

@@ -871,7 +871,7 @@ class MMDiT(nn.Module):
# remove pos_embed to free up memory up to 0.4 GB
self.pos_embed = None
# remove duplcates and sort latent sizes in ascending order
# remove duplicates and sort latent sizes in ascending order
latent_sizes = list(set(latent_sizes))
latent_sizes = sorted(latent_sizes)

View File

@@ -253,12 +253,12 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
" / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります",
)
# copy from Diffusers
# Dependencies of Diffusers noise sampler has been removed for clarity.
parser.add_argument(
"--weighting_scheme",
type=str,
default="logit_normal",
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
default="uniform",
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"],
help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム",
)
parser.add_argument(
@@ -279,6 +279,12 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. / モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効",
)
parser.add_argument(
"--training_shift",
type=float,
default=1.0,
help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。",
)
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
@@ -951,9 +957,10 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
return weighting
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# endregion
def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz = latents.shape[0]
# Sample a random timestep for each image
@@ -965,14 +972,19 @@ def get_noisy_model_input_and_timesteps(
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
t_min = args.min_timestep if args.min_timestep is not None else 0
t_max = args.max_timestep if args.max_timestep is not None else 1000
shift = args.training_shift
# Add noise according to flow matching.
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
# weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details)
u = (u * shift) / (1 + (shift - 1) * u)
indices = (u * (t_max - t_min) + t_min).long()
timesteps = indices.to(device=device, dtype=dtype)
# sigmas according to flowmatching
sigmas = timesteps / 1000
sigmas = sigmas.view(-1, 1, 1, 1)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
return noisy_model_input, timesteps, sigmas
# endregion