From 6f0e235f2cb9a9829bc12280c29e12c0ae66c88f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 11 Jul 2024 08:00:45 +0900 Subject: [PATCH] Fix shift value in SD3 inference. --- sd3_minimal_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index 7f5f28ce..ffa0d46d 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -64,7 +64,7 @@ def do_sample( device: str, ): if initial_latent is None: - # latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 + # latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 # this seems to be a bug in the original code. thanks to furusu for pointing it out latent = torch.zeros(1, 16, height // 8, width // 8, device=device) else: latent = initial_latent @@ -73,7 +73,7 @@ def do_sample( noise = get_noise(seed, latent).to(device) - model_sampling = sd3_utils.ModelSamplingDiscreteFlow() + model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 sigmas = get_sigmas(model_sampling, steps).to(device) # sigmas = sigmas[int(steps * (1 - denoise)) :] # do not support i2i