From af5adb2b61f59710a36f5e3edd0ac567f030aeba Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 23 Mar 2025 21:28:08 -0400 Subject: [PATCH 1/2] Add flexibility to block swapping for Flex model --- library/flux_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/flux_models.py b/library/flux_models.py index 328ad481..524af997 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -968,7 +968,7 @@ class Flux(nn.Module): def enable_block_swap(self, num_blocks: int, device: torch.device): self.blocks_to_swap = num_blocks - double_blocks_to_swap = num_blocks // 2 + double_blocks_to_swap = min(self.num_double_blocks - 2, num_blocks // 2) single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( From dd9a3308604b762cc1b3871142046646f5c33e83 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 23 Mar 2025 21:30:45 -0400 Subject: [PATCH 2/2] Change single blocks to use num_blocks value --- library/flux_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/flux_models.py b/library/flux_models.py index 524af997..2e77b944 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -969,7 +969,7 @@ class Flux(nn.Module): def enable_block_swap(self, num_blocks: int, device: torch.device): self.blocks_to_swap = num_blocks double_blocks_to_swap = min(self.num_double_blocks - 2, num_blocks // 2) - single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + single_blocks_to_swap = (num_blocks - (num_blocks // 2)) * 2 assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. "