mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
reduce peak VRAM usage by excluding some blocks from CUDA
This commit is contained in:
@@ -953,6 +953,22 @@ class Flux(nn.Module):
|
||||
self.double_blocks_to_swap = double_blocks
|
||||
self.single_blocks_to_swap = single_blocks
|
||||
|
||||
def move_to_device_except_swap_blocks(self, device: torch.device):
    """Move the model to ``device`` while leaving the swap blocks behind.

    Blocks flagged by ``double_blocks_to_swap`` / ``single_blocks_to_swap``
    are temporarily detached from the module so that ``self.to(device)``
    does not transfer them, then reattached afterwards (still on their
    original device). This keeps peak VRAM usage down.
    """
    # NOTE: the model is assumed to currently live on the CPU.
    saved_double = None
    saved_single = None
    if self.double_blocks_to_swap:
        saved_double, self.double_blocks = self.double_blocks, None
    if self.single_blocks_to_swap:
        saved_single, self.single_blocks = self.single_blocks, None

    # Everything still attached gets moved; detached blocks are untouched.
    self.to(device)

    if self.double_blocks_to_swap:
        self.double_blocks = saved_double
    if self.single_blocks_to_swap:
        self.single_blocks = saved_single
|
||||
|
||||
def prepare_block_swap_before_forward(self):
|
||||
# move last n blocks to cpu: they are on cuda
|
||||
if self.double_blocks_to_swap:
|
||||
|
||||
Reference in New Issue
Block a user