diff --git a/library/leco_train_util.py b/library/leco_train_util.py index eea3d190..5e95c163 100644 --- a/library/leco_train_util.py +++ b/library/leco_train_util.py @@ -365,7 +365,9 @@ def encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, def apply_noise_offset(latents: torch.Tensor, noise_offset: Optional[float]) -> torch.Tensor: if noise_offset is None: return latents - return latents + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + noise = torch.randn((latents.shape[0], latents.shape[1], 1, 1), dtype=torch.float32, device="cpu") + noise = noise.to(dtype=latents.dtype, device=latents.device) + return latents + noise_offset * noise def get_initial_latents(scheduler, batch_size: int, height: int, width: int, n_prompts: int = 1) -> torch.Tensor: