Fix: Prevent false device mismatch warnings for cuda vs cuda:0

- Treat cuda and cuda:0 as compatible devices
- Only warn on actual device mismatches (cuda vs cpu)
- Eliminates warning spam during multi-subset training
This commit is contained in:
rockerBOO
2025-10-09 16:31:09 -04:00
parent ee8ceee178
commit 4bea582601

View File

@@ -494,13 +494,24 @@ def apply_cdc_noise_transformation(
Transformed noise with geometry-aware covariance
"""
# Device consistency validation
# Normalize device strings: "cuda" -> "cuda:0", "cpu" -> "cpu"
target_device = torch.device(device) if not isinstance(device, torch.device) else device
noise_device = noise.device
if str(noise_device) != str(device):
# Check if devices are compatible (cuda:0 vs cuda should not warn)
devices_compatible = (
noise_device == target_device or
(noise_device.type == "cuda" and target_device.type == "cuda") or
(noise_device.type == "cpu" and target_device.type == "cpu")
)
if not devices_compatible:
logger.warning(
f"CDC device mismatch: noise on {noise_device} but CDC loading to {device}. "
f"Transferring noise to {device} to avoid errors."
f"CDC device mismatch: noise on {noise_device} but CDC loading to {target_device}. "
f"Transferring noise to {target_device} to avoid errors."
)
noise = noise.to(device)
noise = noise.to(target_device)
device = target_device
# Normalize timesteps to [0, 1] for CDC-FM
t_normalized = timesteps.to(device) / num_timesteps