diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index e3c649f7..b04b86fb 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -316,6 +316,8 @@ def do_sample( # noise = get_noise(seed, latent).to(device) if seed is not None: generator = torch.manual_seed(seed) + else: + generator = None noise = ( torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu") .to(latent.dtype)