Compare commits

...

2 Commits

Author SHA1 Message Date
Kohya S
0e8b5d4af9 Merge branch 'sd3' into no-grad-block-swap 2025-05-24 18:50:25 +09:00
kohya-ss
c898e4e536 fix: optimize weight device swapping with no_grad context 2025-03-17 21:28:02 +09:00

View File

@@ -42,19 +42,20 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
     torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value

-    stream = torch.cuda.Stream()
-    with torch.cuda.stream(stream):
-        # cuda to cpu
-        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
-            cuda_data_view.record_stream(stream)
-            module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
+    with torch.no_grad():
+        stream = torch.cuda.Stream()
+        with torch.cuda.stream(stream):
+            # cuda to cpu
+            for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
+                cuda_data_view.record_stream(stream)
+                module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)

-        stream.synchronize()
+            stream.synchronize()

-        # cpu to cuda
-        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
-            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
-            module_to_cuda.weight.data = cuda_data_view
+            # cpu to cuda
+            for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
+                cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
+                module_to_cuda.weight.data = cuda_data_view

-    stream.synchronize()
+        stream.synchronize()
     torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value
@@ -75,14 +76,14 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
     for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
         module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)

-    synchronize_device()
+    synchronize_device(device)

     # cpu to device
     for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
         cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
         module_to_cuda.weight.data = cuda_data_view

-    synchronize_device()
+    synchronize_device(device)


 def weighs_to_device(layer: nn.Module, device: torch.device):