mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 06:28:48 +00:00
Merge 6f7498f959 into 1dae34b0af
This commit is contained in:
@@ -169,11 +169,21 @@ def sample_image_inference(
|
||||
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
if accelerator.device.type == "cuda":
|
||||
torch.cuda.manual_seed(seed)
|
||||
elif accelerator.device.type == "xpu":
|
||||
torch.xpu.manual_seed(seed)
|
||||
elif accelerator.device.type == "mps":
|
||||
torch.mps.manual_seed(seed)
|
||||
else:
|
||||
# True random sample image generation
|
||||
torch.seed()
|
||||
torch.cuda.seed()
|
||||
if accelerator.device.type == "cuda":
|
||||
torch.cuda.seed()
|
||||
elif accelerator.device.type == "xpu":
|
||||
torch.xpu.seed()
|
||||
elif accelerator.device.type == "mps":
|
||||
torch.mps.seed()
|
||||
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ""
|
||||
@@ -474,6 +484,29 @@ def get_noisy_model_input_and_timesteps(
|
||||
bsz, h, w = latents.shape[0], latents.shape[-2], latents.shape[-1]
|
||||
assert bsz > 0, "Batch size not large enough"
|
||||
num_timesteps = noise_scheduler.config.num_train_timesteps
|
||||
min_timestep = 0 if args.min_timestep is None else args.min_timestep
|
||||
max_timestep = num_timesteps if args.max_timestep is None else args.max_timestep
|
||||
if min_timestep > max_timestep:
|
||||
min_timestep, max_timestep = max_timestep, min_timestep
|
||||
|
||||
if min_timestep == max_timestep:
|
||||
# Deterministic timesteps (used by validation) need fully fixed noise.
|
||||
timestep_value = float(max_timestep)
|
||||
timesteps = torch.full((bsz,), timestep_value, device=device, dtype=dtype)
|
||||
sigma_value = timestep_value / num_timesteps
|
||||
sigmas = torch.full((bsz,), sigma_value, device=device, dtype=dtype)
|
||||
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||
if args.ip_noise_gamma:
|
||||
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
|
||||
if args.ip_noise_gamma_random_strength:
|
||||
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
|
||||
else:
|
||||
ip_noise_gamma = args.ip_noise_gamma
|
||||
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
|
||||
else:
|
||||
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
|
||||
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||
|
||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||
# Simple random sigma-based noise sampling
|
||||
if args.timestep_sampling == "sigmoid":
|
||||
|
||||
Reference in New Issue
Block a user