mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Fix: Prevent false device mismatch warnings for cuda vs cuda:0
- Treat cuda and cuda:0 as compatible devices - Only warn on actual device mismatches (cuda vs cpu) - Eliminates warning spam during multi-subset training
This commit is contained in:
@@ -494,13 +494,24 @@ def apply_cdc_noise_transformation(
|
|||||||
Transformed noise with geometry-aware covariance
|
Transformed noise with geometry-aware covariance
|
||||||
"""
|
"""
|
||||||
# Device consistency validation
|
# Device consistency validation
|
||||||
|
# Normalize device strings: "cuda" -> "cuda:0", "cpu" -> "cpu"
|
||||||
|
target_device = torch.device(device) if not isinstance(device, torch.device) else device
|
||||||
noise_device = noise.device
|
noise_device = noise.device
|
||||||
if str(noise_device) != str(device):
|
|
||||||
|
# Check if devices are compatible (cuda:0 vs cuda should not warn)
|
||||||
|
devices_compatible = (
|
||||||
|
noise_device == target_device or
|
||||||
|
(noise_device.type == "cuda" and target_device.type == "cuda") or
|
||||||
|
(noise_device.type == "cpu" and target_device.type == "cpu")
|
||||||
|
)
|
||||||
|
|
||||||
|
if not devices_compatible:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"CDC device mismatch: noise on {noise_device} but CDC loading to {device}. "
|
f"CDC device mismatch: noise on {noise_device} but CDC loading to {target_device}. "
|
||||||
f"Transferring noise to {device} to avoid errors."
|
f"Transferring noise to {target_device} to avoid errors."
|
||||||
)
|
)
|
||||||
noise = noise.to(device)
|
noise = noise.to(target_device)
|
||||||
|
device = target_device
|
||||||
|
|
||||||
# Normalize timesteps to [0, 1] for CDC-FM
|
# Normalize timesteps to [0, 1] for CDC-FM
|
||||||
t_normalized = timesteps.to(device) / num_timesteps
|
t_normalized = timesteps.to(device) / num_timesteps
|
||||||
|
|||||||
Reference in New Issue
Block a user