Fix validation block swap. Add custom offloading tests

This commit is contained in:
rockerBOO
2025-02-27 20:36:36 -05:00
parent 42fe22f5a2
commit 9647f1e324
7 changed files with 446 additions and 32 deletions

View File

@@ -1080,7 +1080,7 @@ class MMDiT(nn.Module):
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
self.offloader = custom_offloading_utils.ModelOffloader(
self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True
self.joint_blocks, self.blocks_to_swap, device # , debug=True
)
print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
@@ -1088,7 +1088,7 @@ class MMDiT(nn.Module):
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
if self.blocks_to_swap:
save_blocks = self.joint_blocks
self.joint_blocks = None
self.joint_blocks = nn.ModuleList()
self.to(device)