diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index c96e4bb6..3247ecbe 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -169,11 +169,21 @@ def sample_image_inference( if seed is not None: torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + if accelerator.device.type == "cuda": + torch.cuda.manual_seed(seed) + elif accelerator.device.type == "xpu": + torch.xpu.manual_seed(seed) + elif accelerator.device.type == "mps": + torch.mps.manual_seed(seed) else: # True random sample image generation torch.seed() - torch.cuda.seed() + if accelerator.device.type == "cuda": + torch.cuda.seed() + elif accelerator.device.type == "xpu": + torch.xpu.seed() + elif accelerator.device.type == "mps": + torch.mps.seed() if negative_prompt is None: negative_prompt = "" @@ -474,6 +484,29 @@ def get_noisy_model_input_and_timesteps( bsz, h, w = latents.shape[0], latents.shape[-2], latents.shape[-1] assert bsz > 0, "Batch size not large enough" num_timesteps = noise_scheduler.config.num_train_timesteps + min_timestep = 0 if args.min_timestep is None else args.min_timestep + max_timestep = num_timesteps if args.max_timestep is None else args.max_timestep + if min_timestep > max_timestep: + min_timestep, max_timestep = max_timestep, min_timestep + + if min_timestep == max_timestep: + # Deterministic timesteps (used by validation) need fully fixed noise. + timestep_value = float(max_timestep) + timesteps = torch.full((bsz,), timestep_value, device=device, dtype=dtype) + sigma_value = timestep_value / num_timesteps + sigmas = torch.full((bsz,), sigma_value, device=device, dtype=dtype) + sigmas = sigmas.view(-1, 1, 1, 1) + if args.ip_noise_gamma: + xi = torch.randn_like(latents, device=latents.device, dtype=dtype) + if args.ip_noise_gamma_random_strength: + ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma + else: + ip_noise_gamma = args.ip_noise_gamma + noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi) + else: + noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise + return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random sigma-based noise sampling if args.timestep_sampling == "sigmoid":