mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Update README and clean-up the code for SD3 timesteps
This commit is contained in:
@@ -526,7 +526,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
secondary_separator: {subset.secondary_separator}
|
||||
enable_wildcard: {subset.enable_wildcard}
|
||||
caption_dropout_rate: {subset.caption_dropout_rate}
|
||||
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
|
||||
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
|
||||
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
||||
caption_prefix: {subset.caption_prefix}
|
||||
caption_suffix: {subset.caption_suffix}
|
||||
|
||||
@@ -871,7 +871,7 @@ class MMDiT(nn.Module):
|
||||
# remove pos_embed to free up memory up to 0.4 GB
|
||||
self.pos_embed = None
|
||||
|
||||
# remove duplcates and sort latent sizes in ascending order
|
||||
# remove duplicates and sort latent sizes in ascending order
|
||||
latent_sizes = list(set(latent_sizes))
|
||||
latent_sizes = sorted(latent_sizes)
|
||||
|
||||
|
||||
@@ -253,7 +253,7 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
|
||||
" / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります",
|
||||
)
|
||||
|
||||
# Dependencies of Diffusers noise sampler has been removed for clearity.
|
||||
# Dependencies of Diffusers noise sampler has been removed for clarity.
|
||||
parser.add_argument(
|
||||
"--weighting_scheme",
|
||||
type=str,
|
||||
@@ -285,7 +285,8 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
|
||||
default=1.0,
|
||||
help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。",
|
||||
)
|
||||
|
||||
|
||||
|
||||
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
||||
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
||||
if args.v_parameterization:
|
||||
@@ -956,9 +957,10 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
return weighting
|
||||
|
||||
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, device, dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# endregion
|
||||
|
||||
|
||||
def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# Sample a random timestep for each image
|
||||
@@ -977,13 +979,12 @@ def get_noisy_model_input_and_timesteps(
|
||||
# weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details)
|
||||
u = (u * shift) / (1 + (shift - 1) * u)
|
||||
|
||||
indices = (u * (t_max-t_min) + t_min).long()
|
||||
indices = (u * (t_max - t_min) + t_min).long()
|
||||
timesteps = indices.to(device=device, dtype=dtype)
|
||||
|
||||
# sigmas according to flowmatching
|
||||
sigmas = timesteps / 1000
|
||||
sigmas = sigmas.view(-1,1,1,1)
|
||||
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
||||
|
||||
return noisy_model_input, timesteps, sigmas
|
||||
|
||||
|
||||
Reference in New Issue
Block a user