Kohya S
2024-07-09 20:37:00 +09:00
parent 3ea4fce5e0
commit 9dc7997803
3 changed files with 3 additions and 3 deletions


@@ -1643,7 +1643,7 @@ class T5LayerNorm(torch.nn.Module):
     # copy from transformers' T5LayerNorm
     def forward(self, hidden_states):
         # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
-        # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
+        # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
         # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
         # half-precision inputs is done in fp32
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)

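For reference, the comment corrected above describes RMS Layer Normalization: the variance is computed as the mean of the squared activations (no mean subtraction), the input is scaled by the reciprocal square root, and there is no bias term. A minimal runnable sketch of the full module, modeled on the transformers T5LayerNorm that the comment cites (the weight parameter, the eps default, and the final dtype cast follow that implementation and are not shown in this hunk):

import torch

class T5LayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # RMS norm: a learned scale, but no bias (scales without shifting)
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # accumulate in fp32 even for half-precision inputs
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # cast back so half-precision weights multiply half-precision activations
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states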

@@ -279,7 +279,7 @@ def sample_images(*args, **kwargs):
     return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)


-class Sd3LatensCachingStrategy(train_util.LatentsCachingStrategy):
+class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy):
     SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"

     def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:


@@ -217,7 +217,7 @@ def train(args):
             file_suffix="_sd3.npz",
         )
     else:
-        strategy = sd3_train_utils.Sd3LatensCachingStrategy(
+        strategy = sd3_train_utils.Sd3LatentsCachingStrategy(
             vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
         )
     train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy)