mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix typo
This commit is contained in:
@@ -1643,7 +1643,7 @@ class T5LayerNorm(torch.nn.Module):
|
|||||||
# copy from transformers' T5LayerNorm
|
# copy from transformers' T5LayerNorm
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
||||||
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
|
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
|
||||||
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
||||||
# half-precision inputs is done in fp32
|
# half-precision inputs is done in fp32
|
||||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||||
|
|||||||
@@ -279,7 +279,7 @@ def sample_images(*args, **kwargs):
|
|||||||
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class Sd3LatensCachingStrategy(train_util.LatentsCachingStrategy):
|
class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy):
|
||||||
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
|
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
|
||||||
|
|
||||||
def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||||
|
|||||||
@@ -217,7 +217,7 @@ def train(args):
|
|||||||
file_suffix="_sd3.npz",
|
file_suffix="_sd3.npz",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
strategy = sd3_train_utils.Sd3LatensCachingStrategy(
|
strategy = sd3_train_utils.Sd3LatentsCachingStrategy(
|
||||||
vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
|
vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
|
||||||
)
|
)
|
||||||
train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy)
|
train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy)
|
||||||
|
|||||||
Reference in New Issue
Block a user