mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Compare commits
2 Commits
80af9d50e9
...
a10ca40d37
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a10ca40d37 | ||
|
|
6f7498f959 |
@@ -169,11 +169,21 @@ def sample_image_inference(
|
|||||||
|
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
torch.cuda.manual_seed(seed)
|
if accelerator.device.type == "cuda":
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
elif accelerator.device.type == "xpu":
|
||||||
|
torch.xpu.manual_seed(seed)
|
||||||
|
elif accelerator.device.type == "mps":
|
||||||
|
torch.mps.manual_seed(seed)
|
||||||
else:
|
else:
|
||||||
# True random sample image generation
|
# True random sample image generation
|
||||||
torch.seed()
|
torch.seed()
|
||||||
torch.cuda.seed()
|
if accelerator.device.type == "cuda":
|
||||||
|
torch.cuda.seed()
|
||||||
|
elif accelerator.device.type == "xpu":
|
||||||
|
torch.xpu.seed()
|
||||||
|
elif accelerator.device.type == "mps":
|
||||||
|
torch.mps.seed()
|
||||||
|
|
||||||
if negative_prompt is None:
|
if negative_prompt is None:
|
||||||
negative_prompt = ""
|
negative_prompt = ""
|
||||||
@@ -474,6 +484,29 @@ def get_noisy_model_input_and_timesteps(
|
|||||||
bsz, h, w = latents.shape[0], latents.shape[-2], latents.shape[-1]
|
bsz, h, w = latents.shape[0], latents.shape[-2], latents.shape[-1]
|
||||||
assert bsz > 0, "Batch size not large enough"
|
assert bsz > 0, "Batch size not large enough"
|
||||||
num_timesteps = noise_scheduler.config.num_train_timesteps
|
num_timesteps = noise_scheduler.config.num_train_timesteps
|
||||||
|
min_timestep = 0 if args.min_timestep is None else args.min_timestep
|
||||||
|
max_timestep = num_timesteps if args.max_timestep is None else args.max_timestep
|
||||||
|
if min_timestep > max_timestep:
|
||||||
|
min_timestep, max_timestep = max_timestep, min_timestep
|
||||||
|
|
||||||
|
if min_timestep == max_timestep:
|
||||||
|
# Deterministic timesteps (used by validation) need fully fixed noise.
|
||||||
|
timestep_value = float(max_timestep)
|
||||||
|
timesteps = torch.full((bsz,), timestep_value, device=device, dtype=dtype)
|
||||||
|
sigma_value = timestep_value / num_timesteps
|
||||||
|
sigmas = torch.full((bsz,), sigma_value, device=device, dtype=dtype)
|
||||||
|
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||||
|
if args.ip_noise_gamma:
|
||||||
|
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
|
||||||
|
if args.ip_noise_gamma_random_strength:
|
||||||
|
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
|
||||||
|
else:
|
||||||
|
ip_noise_gamma = args.ip_noise_gamma
|
||||||
|
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
|
||||||
|
else:
|
||||||
|
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
|
||||||
|
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||||
|
|
||||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||||
# Simple random sigma-based noise sampling
|
# Simple random sigma-based noise sampling
|
||||||
if args.timestep_sampling == "sigmoid":
|
if args.timestep_sampling == "sigmoid":
|
||||||
|
|||||||
Reference in New Issue
Block a user