mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add IP noise gamma for Flux
This commit is contained in:
@@ -415,6 +415,16 @@ def get_noisy_model_input_and_timesteps(
|
||||
bsz, _, h, w = latents.shape
|
||||
sigmas = None
|
||||
|
||||
ip_noise_gamma = 0.0
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
if args.ip_noise_gamma:
|
||||
if args.ip_noise_gamma_random_strength:
|
||||
ip_noise_gamma = torch.rand(1, device=latents.device) * args.ip_noise_gamma
|
||||
else:
|
||||
ip_noise_gamma = args.ip_noise_gamma
|
||||
|
||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||
# Simple random t-based noise sampling
|
||||
if args.timestep_sampling == "sigmoid":
|
||||
@@ -425,7 +435,7 @@ def get_noisy_model_input_and_timesteps(
|
||||
|
||||
timesteps = t * 1000.0
|
||||
t = t.view(-1, 1, 1, 1)
|
||||
noisy_model_input = (1 - t) * latents + t * noise
|
||||
noisy_model_input = (1 - t) * latents + t * noise + ip_noise_gamma
|
||||
elif args.timestep_sampling == "shift":
|
||||
shift = args.discrete_flow_shift
|
||||
logits_norm = torch.randn(bsz, device=device)
|
||||
@@ -435,7 +445,7 @@ def get_noisy_model_input_and_timesteps(
|
||||
|
||||
t = timesteps.view(-1, 1, 1, 1)
|
||||
timesteps = timesteps * 1000.0
|
||||
noisy_model_input = (1 - t) * latents + t * noise
|
||||
noisy_model_input = (1 - t) * latents + t * noise + ip_noise_gamma
|
||||
elif args.timestep_sampling == "flux_shift":
|
||||
logits_norm = torch.randn(bsz, device=device)
|
||||
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
@@ -445,7 +455,7 @@ def get_noisy_model_input_and_timesteps(
|
||||
|
||||
t = timesteps.view(-1, 1, 1, 1)
|
||||
timesteps = timesteps * 1000.0
|
||||
noisy_model_input = (1 - t) * latents + t * noise
|
||||
noisy_model_input = (1 - t) * latents + t * noise + ip_noise_gamma
|
||||
else:
|
||||
# Sample a random timestep for each image
|
||||
# for weighting schemes where we sample timesteps non-uniformly
|
||||
@@ -461,7 +471,8 @@ def get_noisy_model_input_and_timesteps(
|
||||
|
||||
# Add noise according to flow matching.
|
||||
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
||||
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
||||
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + ip_noise_gamma
|
||||
|
||||
|
||||
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||
|
||||
|
||||
Reference in New Issue
Block a user