This commit is contained in:
Matthew Turnshek
2026-03-31 12:55:53 +00:00
committed by GitHub

View File

@@ -169,11 +169,21 @@ def sample_image_inference(
if seed is not None:
torch.manual_seed(seed)
if accelerator.device.type == "cuda":
torch.cuda.manual_seed(seed)
elif accelerator.device.type == "xpu":
torch.xpu.manual_seed(seed)
elif accelerator.device.type == "mps":
torch.mps.manual_seed(seed)
else:
# True random sample image generation
torch.seed()
if accelerator.device.type == "cuda":
torch.cuda.seed()
elif accelerator.device.type == "xpu":
torch.xpu.seed()
elif accelerator.device.type == "mps":
torch.mps.seed()
if negative_prompt is None:
negative_prompt = ""
@@ -474,6 +484,29 @@ def get_noisy_model_input_and_timesteps(
bsz, h, w = latents.shape[0], latents.shape[-2], latents.shape[-1]
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = num_timesteps if args.max_timestep is None else args.max_timestep
if min_timestep > max_timestep:
min_timestep, max_timestep = max_timestep, min_timestep
if min_timestep == max_timestep:
# Deterministic timesteps (used by validation) need fully fixed noise.
timestep_value = float(max_timestep)
timesteps = torch.full((bsz,), timestep_value, device=device, dtype=dtype)
sigma_value = timestep_value / num_timesteps
sigmas = torch.full((bsz,), sigma_value, device=device, dtype=dtype)
sigmas = sigmas.view(-1, 1, 1, 1)
if args.ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if args.ip_noise_gamma_random_strength:
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
else:
ip_noise_gamma = args.ip_noise_gamma
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
else:
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random sigma-based noise sampling
if args.timestep_sampling == "sigmoid":