Fix shift value in SD3 inference.

This commit is contained in:
Kohya S
2024-07-11 08:00:45 +09:00
parent 3d402927ef
commit 6f0e235f2c

View File

@@ -64,7 +64,7 @@ def do_sample(
device: str,
):
if initial_latent is None:
# latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609
# latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 # this seems to be a bug in the original code. thanks to furusu for pointing it out
latent = torch.zeros(1, 16, height // 8, width // 8, device=device)
else:
latent = initial_latent
@@ -73,7 +73,7 @@ def do_sample(
noise = get_noise(seed, latent).to(device)
model_sampling = sd3_utils.ModelSamplingDiscreteFlow()
model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3
sigmas = get_sigmas(model_sampling, steps).to(device)
# sigmas = sigmas[int(steps * (1 - denoise)) :] # do not support i2i