reduce peak VRAM usage by excluding some blocks from being moved to cuda

This commit is contained in:
Kohya S
2024-08-19 21:55:28 +09:00
parent d034032a5d
commit 6e72a799c8
2 changed files with 25 additions and 6 deletions

View File

@@ -953,6 +953,22 @@ class Flux(nn.Module):
self.double_blocks_to_swap = double_blocks
self.single_blocks_to_swap = single_blocks
def move_to_device_except_swap_blocks(self, device: torch.device):
    """Move the model to `device`, leaving the swap-designated blocks behind.

    Assumes the model currently resides on the CPU (per the original note).
    Blocks flagged by `double_blocks_to_swap` / `single_blocks_to_swap` are
    temporarily detached so `self.to(device)` skips them, then re-attached
    afterwards, still on their original device.
    """
    detached_double = None
    detached_single = None

    # Detach the swap-designated block lists so .to() does not touch them.
    if self.double_blocks_to_swap:
        detached_double, self.double_blocks = self.double_blocks, None
    if self.single_blocks_to_swap:
        detached_single, self.single_blocks = self.single_blocks, None

    self.to(device)

    # Re-attach whatever was detached above.
    if self.double_blocks_to_swap:
        self.double_blocks = detached_double
    if self.single_blocks_to_swap:
        self.single_blocks = detached_single
def prepare_block_swap_before_forward(self):
# move last n blocks to cpu: they are on cuda
if self.double_blocks_to_swap: