This commit is contained in:
Dave Lage
2026-02-24 03:16:04 +01:00
committed by GitHub
9 changed files with 77 additions and 57 deletions

View File

@@ -53,7 +53,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
# print(
# f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
# )
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device, non_blocking=True)
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value