From 4bea5826011ef3134b3a852b22a0239ec6c3042e Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 9 Oct 2025 16:31:09 -0400 Subject: [PATCH] Fix: Prevent false device mismatch warnings for cuda vs cuda:0 - Treat cuda and cuda:0 as compatible devices - Only warn on actual device mismatches (cuda vs cpu) - Eliminates warning spam during multi-subset training --- library/flux_train_utils.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index cfc646f0..a51d125a 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -494,13 +494,24 @@ def apply_cdc_noise_transformation( Transformed noise with geometry-aware covariance """ # Device consistency validation + # Normalize device strings: "cuda" -> "cuda:0", "cpu" -> "cpu" + target_device = torch.device(device) if not isinstance(device, torch.device) else device noise_device = noise.device - if str(noise_device) != str(device): + + # Check if devices are compatible (cuda:0 vs cuda should not warn) + devices_compatible = ( + noise_device == target_device or + (noise_device.type == "cuda" and target_device.type == "cuda") or + (noise_device.type == "cpu" and target_device.type == "cpu") + ) + + if not devices_compatible: logger.warning( - f"CDC device mismatch: noise on {noise_device} but CDC loading to {device}. " - f"Transferring noise to {device} to avoid errors." + f"CDC device mismatch: noise on {noise_device} but CDC loading to {target_device}. " + f"Transferring noise to {target_device} to avoid errors." ) - noise = noise.to(device) + noise = noise.to(target_device) + device = target_device # Normalize timesteps to [0, 1] for CDC-FM t_normalized = timesteps.to(device) / num_timesteps