mirror of https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
remove unused weight swapping functions from utils.py
library/utils.py: 185 lines changed
@@ -94,26 +94,6 @@ def setup_logging(args=None, log_level=None, reset=False):
 
 # region PyTorch utils
-
-# def swap_weights(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
-#     assert layer_to_cpu.__class__ == layer_to_cuda.__class__
-#     for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
-#         if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
-#             # print(f"Swapping {layer_to_cpu.__class__.__name__}-{module_to_cpu.__class__.__name__}.")
-#             # cpu_tensor = module_to_cuda.weight.data
-#             # cuda_tensor = module_to_cpu.weight.data
-#             # assert cuda_tensor.device.type == "cuda"
-#             # temp_cpu_tensor = cuda_tensor.to("cpu", non_blocking=True)
-#             # torch.cuda.current_stream().synchronize()
-#             # cuda_tensor.copy_(cpu_tensor, non_blocking=True)
-#             # torch.cuda.current_stream().synchronize()
-#             # cpu_tensor.copy_(temp_cpu_tensor, non_blocking=True)
-#             # module_to_cpu.weight.data, module_to_cuda.weight.data = cpu_tensor, cuda_tensor
-#             cuda_tensor_view = module_to_cpu.weight.data
-#             cpu_tensor_view = module_to_cuda.weight.data
-#             module_to_cpu.weight.data = module_to_cpu.weight.to("cpu", non_blocking=True).detach().clone()
-#             module_to_cuda.weight.data = cuda_tensor_view
-#             module_to_cuda.weight.data.copy_(cpu_tensor_view)
 
 
 def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
     assert layer_to_cpu.__class__ == layer_to_cuda.__class__
@@ -143,171 +123,6 @@ def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
     torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
 
 
-def swap_weight_devices_2st(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
-    assert layer_to_cpu.__class__ == layer_to_cuda.__class__
-
-    weight_swap_jobs = []
-    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
-        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
-            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
-
-    stream_to_cpu = torch.cuda.Stream()
-    stream_to_cuda = torch.cuda.Stream()
-
-    events = []
-    with torch.cuda.stream(stream_to_cpu):
-        # cuda to offload
-        offloaded_weights = []
-        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
-            offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True))
-            event = torch.cuda.Event()
-            event.record(stream=stream_to_cpu)
-            events.append(event)
-
-    with torch.cuda.stream(stream_to_cuda):
-        # cpu to cuda
-        for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event in zip(weight_swap_jobs, events):
-            event.synchronize()
-            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
-            module_to_cuda.weight.data = cuda_data_view
-
-        # offload to cpu
-        for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), offloaded_weight in zip(
-            weight_swap_jobs, offloaded_weights
-        ):
-            module_to_cpu.weight.data = offloaded_weight
-
-    stream_to_cuda.synchronize()
-
-    torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
-
-
-def swap_weight_devices_failed(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
-    assert layer_to_cpu.__class__ == layer_to_cuda.__class__
-
-    weight_swap_jobs = []
-    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
-        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
-            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
-
-    stream_to_cpu = torch.cuda.Stream()
-    stream_to_cuda = torch.cuda.Stream()
-
-    # cuda to offload
-    events = []
-    with torch.cuda.stream(stream_to_cpu):
-        offloaded_weights = []
-        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
-            cuda_data_view.record_stream(stream_to_cpu)
-            offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True))
-
-            event = torch.cuda.Event()
-            event.record(stream=stream_to_cpu)
-            events.append(event)
-
-    # cpu to cuda
-    with torch.cuda.stream(stream_to_cuda):
-        for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event, offloaded_weight in zip(
-            weight_swap_jobs, events, offloaded_weights
-        ):
-            event.synchronize()
-            cuda_data_view.record_stream(stream_to_cuda)
-            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
-            module_to_cuda.weight.data = cuda_data_view
-
-            module_to_cpu.weight.data = offloaded_weight
-
-    stream_to_cuda.synchronize()
-
-    torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
-    # torch.cuda.current_stream().wait_stream(stream_to_cuda)
-    # for job in weight_swap_jobs:
-    #     job[2].record_stream(torch.cuda.current_stream()) # record the ownership of the tensor
-
-
-def swap_weight_devices_works_2(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
-    assert layer_to_cpu.__class__ == layer_to_cuda.__class__
-
-    weight_swap_jobs = []
-    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
-        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
-            if not (hasattr(module_to_cpu, "offloaded_weight") or hasattr(module_to_cuda, "offloaded_weight")):
-                # one of the modules must have the tensor to offload
-                module_to_cpu.offloaded_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu")
-                module_to_cpu.offloaded_weight.pin_memory()
-            offloaded_weight = (
-                module_to_cpu.offloaded_weight if hasattr(module_to_cpu, "offloaded_weight") else module_to_cuda.offloaded_weight
-            )
-            assert module_to_cpu.weight.device.type == "cuda" and module_to_cuda.weight.device.type == "cpu"
-            weight_swap_jobs.append(
-                (module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data, offloaded_weight)
-            )
-
-    stream = torch.cuda.Stream()
-    with torch.cuda.stream(stream):
-        # cuda to offload
-        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
-            cuda_data_view.record_stream(stream)
-            offloaded_weight.copy_(module_to_cpu.weight.data, non_blocking=True)
-
-        stream.synchronize()
-
-        # cpu to cuda
-        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
-            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
-            module_to_cuda.weight.data = cuda_data_view
-
-        # offload to cpu
-        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
-            module_to_cpu.weight.data = offloaded_weight
-            offloaded_weight = cpu_data_view
-            module_to_cpu.offloaded_weight = offloaded_weight
-            module_to_cuda.offloaded_weight = offloaded_weight
-
-        stream.synchronize()
-
-    torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
-
-
-def swap_weight_devices_safe_works(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
-    assert layer_to_cpu.__class__ == layer_to_cuda.__class__
-
-    weight_swap_jobs = []
-    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
-        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
-            if not (hasattr(module_to_cpu, "__cached_cpu_weight") or hasattr(module_to_cuda, "__cached_cuda_weight")):
-                # one of the modules must have the tensor to cache
-                module_to_cpu.__cached_cpu_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu")
-                module_to_cpu.__cached_cpu_weight.pin_memory()
-
-            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
-
-    for module_to_cpu, module_to_cuda, cuda_tensor_view, cpu_tensor_view in weight_swap_jobs:
-        module_to_cpu.weight.data = cuda_tensor_view.to("cpu", non_blocking=True)
-        module_to_cuda.weight.data = cpu_tensor_view.to("cuda", non_blocking=True)
-
-    torch.cuda.current_stream().synchronize() # wait for the copy from cache to cpu to finish
-    torch.cuda.empty_cache()
-
-
-# def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
-#     assert layer_to_cpu.__class__ == layer_to_cuda.__class__
-#     for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
-#         if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
-#             assert module_to_cuda.weight.device.type == "cpu" and module_to_cpu.weight.device.type == "cuda"
-#             weight_on_cuda = module_to_cpu.weight
-#             weight_on_cpu = module_to_cuda.weight
-#             cuda_to_cpu_data = weight_on_cuda.data.to("cpu", non_blocking=True)
-#             event = torch.cuda.current_stream().record_event()
-#             event.synchronize()
-#             weight_on_cuda.data.copy_(weight_on_cpu.data, non_blocking=True)
-#             weight_on_cpu.data = cuda_to_cpu_data
-#             weight_on_cpu.grad, weight_on_cuda.grad = weight_on_cuda.grad, weight_on_cpu.grad
-
-#             module_to_cpu.weight = weight_on_cpu
-#             module_to_cuda.weight = weight_on_cuda
-
-
 def weighs_to_device(layer: nn.Module, device: torch.device):
     for module in layer.modules():
         if hasattr(module, "weight") and module.weight is not None:
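For context, the function this commit keeps, swap_weight_devices, pairs a block whose weights currently live on the GPU with a structurally identical block whose weights live on the CPU and exchanges the two sets of weights. The sketch below illustrates that call pattern only; the simplified swap_weight_devices_sketch helper and the toy nn.Linear "blocks" are assumptions for illustration, not the actual sd-scripts implementation or API.

import torch
import torch.nn as nn


def swap_weight_devices_sketch(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    # Both layers must have the same structure so their submodules pair up one-to-one.
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__
    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
        # Like the retained function, only modules that expose a `weight` tensor are swapped.
        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
            cuda_view = module_to_cpu.weight.data  # weights currently on the GPU
            cpu_view = module_to_cuda.weight.data  # weights currently on the CPU
            module_to_cpu.weight.data = cuda_view.to("cpu", non_blocking=True)
            module_to_cuda.weight.data = cpu_view.to("cuda", non_blocking=True)
    # Make sure the device copies have finished before either block is used again.
    torch.cuda.current_stream().synchronize()


if __name__ == "__main__" and torch.cuda.is_available():
    blocks = [nn.Linear(1024, 1024) for _ in range(3)]  # hypothetical transformer blocks
    blocks[0].to("cuda")  # only one block is resident on the GPU at a time
    # After running blocks[0], evict it to the CPU and bring blocks[1] onto the GPU.
    swap_weight_devices_sketch(layer_to_cpu=blocks[0], layer_to_cuda=blocks[1])
    assert blocks[0].weight.device.type == "cpu" and blocks[1].weight.device.type == "cuda"

The deleted variants above were alternative attempts at overlapping these copies with compute via dedicated CUDA streams, events, and pinned CPU buffers; the diff leaves swap_weight_devices and weighs_to_device in place.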