Perturbed noise should be separate of input noise

This commit is contained in:
rockerBOO
2025-03-19 00:25:51 -04:00
parent b81bcd0b01
commit 7197266703
2 changed files with 12 additions and 10 deletions

View File

@@ -350,15 +350,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
if args.ip_noise_gamma_random_strength:
noise = noise + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents)
else:
noise = noise + args.ip_noise_gamma * torch.randn_like(latents)
bsz = latents.shape[0]
# get noisy model input and timesteps

View File

@@ -410,11 +410,22 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
args, noise_scheduler, latents: torch.Tensor, input_noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
sigmas = None
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
if args.ip_noise_gamma_random_strength:
noise = input_noise.detach().clone() + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents)
else:
noise = input_noise.detach().clone() + args.ip_noise_gamma * torch.randn_like(latents)
else:
noise = input_noise
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling
if args.timestep_sampling == "sigmoid":