From bfdbf04059f44b52eca6fb1295e25d18d798ad55 Mon Sep 17 00:00:00 2001
From: Kohya S <52813779+kohya-ss@users.noreply.github.com>
Date: Sun, 29 Mar 2026 19:06:33 +0900
Subject: [PATCH] =?UTF-8?q?fix:=20apply=5Fnoise=5Foffset=20=E3=81=AE=20dty?=
 =?UTF-8?q?pe=20=E4=B8=8D=E4=B8=80=E8=87=B4=E3=82=92=E4=BF=AE=E6=AD=A3?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

torch.randn のデフォルト float32 により latents が暗黙的にアップキャストされる問題を修正。
float32/CPU で生成後に latents の dtype/device へ変換する安全なパターンを採用。

Co-Authored-By: Claude Opus 4.6
---
 library/leco_train_util.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/library/leco_train_util.py b/library/leco_train_util.py
index eea3d190..5e95c163 100644
--- a/library/leco_train_util.py
+++ b/library/leco_train_util.py
@@ -365,7 +365,9 @@ def encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders,
 def apply_noise_offset(latents: torch.Tensor, noise_offset: Optional[float]) -> torch.Tensor:
     if noise_offset is None:
         return latents
-    return latents + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
+    noise = torch.randn((latents.shape[0], latents.shape[1], 1, 1), dtype=torch.float32, device="cpu")
+    noise = noise.to(dtype=latents.dtype, device=latents.device)
+    return latents + noise_offset * noise
 
 
 def get_initial_latents(scheduler, batch_size: int, height: int, width: int, n_prompts: int = 1) -> torch.Tensor: