diff --git a/library/flux_models.py b/library/flux_models.py index d2d7e06c..54d7bdd5 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -971,8 +971,8 @@ class Flux(nn.Module): def enable_block_swap(self, num_blocks: int, device: torch.device): self.blocks_to_swap = num_blocks - double_blocks_to_swap = num_blocks // 2 - single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + double_blocks_to_swap = min(self.num_double_blocks - 2, num_blocks // 2) + single_blocks_to_swap = (num_blocks - (num_blocks // 2)) * 2 assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. "