diff --git a/flux_train_network.py b/flux_train_network.py index def44155..d85584f5 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -350,6 +350,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + if args.ip_noise_gamma_random_strength: + noise = noise + (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) + else: + noise = noise + args.ip_noise_gamma * torch.randn_like(latents) + bsz = latents.shape[0] # get noisy model input and timesteps diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f0744747..f7f06c5c 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -415,16 +415,6 @@ def get_noisy_model_input_and_timesteps( bsz, _, h, w = latents.shape sigmas = None - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - if args.ip_noise_gamma: - if args.ip_noise_gamma_random_strength: - ip_noise = (torch.rand(1, device=latents.device) * args.ip_noise_gamma) * torch.randn_like(latents) - else: - ip_noise = args.ip_noise_gamma * torch.randn_like(latents) - else: - ip_noise = torch.zeros_like(latents) - if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random t-based noise sampling if args.timestep_sampling == "sigmoid": @@ -435,7 +425,7 @@ def get_noisy_model_input_and_timesteps( timesteps = t * 1000.0 t = t.view(-1, 1, 1, 1) - noisy_model_input = (1 - t) * latents + t * (noise + ip_noise) + noisy_model_input = (1 - t) * latents + t * noise elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) @@ -445,7 +435,7 @@ def get_noisy_model_input_and_timesteps( t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * (noise + ip_noise) + noisy_model_input = (1 - t) * latents + t * noise elif args.timestep_sampling == "flux_shift": logits_norm = torch.randn(bsz, device=device) logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling @@ -455,7 +445,7 @@ def get_noisy_model_input_and_timesteps( t = timesteps.view(-1, 1, 1, 1) timesteps = timesteps * 1000.0 - noisy_model_input = (1 - t) * latents + t * (noise + ip_noise) + noisy_model_input = (1 - t) * latents + t * noise else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -471,8 +461,7 @@ def get_noisy_model_input_and_timesteps( # Add noise according to flow matching. sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) - noisy_model_input = sigmas * (noise + ip_noise) + (1.0 - sigmas) * latents - + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas