mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Fix: Prevent false device mismatch warnings for cuda vs cuda:0
- Treat cuda and cuda:0 as compatible devices - Only warn on actual device mismatches (cuda vs cpu) - Eliminates warning spam during multi-subset training
This commit is contained in:
@@ -494,13 +494,24 @@ def apply_cdc_noise_transformation(
|
||||
Transformed noise with geometry-aware covariance
|
||||
"""
|
||||
# Device consistency validation
|
||||
# Normalize device strings: "cuda" -> "cuda:0", "cpu" -> "cpu"
|
||||
target_device = torch.device(device) if not isinstance(device, torch.device) else device
|
||||
noise_device = noise.device
|
||||
if str(noise_device) != str(device):
|
||||
|
||||
# Check if devices are compatible (cuda:0 vs cuda should not warn)
|
||||
devices_compatible = (
|
||||
noise_device == target_device or
|
||||
(noise_device.type == "cuda" and target_device.type == "cuda") or
|
||||
(noise_device.type == "cpu" and target_device.type == "cpu")
|
||||
)
|
||||
|
||||
if not devices_compatible:
|
||||
logger.warning(
|
||||
f"CDC device mismatch: noise on {noise_device} but CDC loading to {device}. "
|
||||
f"Transferring noise to {device} to avoid errors."
|
||||
f"CDC device mismatch: noise on {noise_device} but CDC loading to {target_device}. "
|
||||
f"Transferring noise to {target_device} to avoid errors."
|
||||
)
|
||||
noise = noise.to(device)
|
||||
noise = noise.to(target_device)
|
||||
device = target_device
|
||||
|
||||
# Normalize timesteps to [0, 1] for CDC-FM
|
||||
t_normalized = timesteps.to(device) / num_timesteps
|
||||
|
||||
Reference in New Issue
Block a user