Added --blocks_to_swap_while_sampling as an optional command line parameter.

This commit is contained in:
araleza
2025-04-20 11:32:07 +01:00
parent 5a18a03ffc
commit dde9936a34
3 changed files with 45 additions and 11 deletions

View File

@@ -194,18 +194,25 @@ class ModelOffloader(Offloader):
return backward_hook
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
def prepare_block_devices_before_forward(self, blocks: list[nn.Module], override_blocks_to_swap = None):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if self.debug:
print("Prepare block devices before forward")
# More blocks might be sent to the card when e.g. performing sample image generation
if override_blocks_to_swap is None or override_blocks_to_swap == -1: # 0 is a valid number of blocks to swap
blocks_to_swap = self.blocks_to_swap
else:
blocks_to_swap = override_blocks_to_swap
assert(override_blocks_to_swap <= self.blocks_to_swap, "Can't override with a greater number of blocks to swap.")
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
if self.debug:
print(f"Prepare block devices before forward. Setting blocks to swap to {blocks_to_swap}.")
for b in blocks[0 : self.num_blocks - blocks_to_swap]:
b.to(self.device)
weighs_to_device(b, self.device) # make sure weights are on device
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
for b in blocks[self.num_blocks - blocks_to_swap :]:
b.to(self.device) # move block to device first
weighs_to_device(b, "cpu") # make sure weights are on cpu

View File

@@ -1000,11 +1000,11 @@ class Flux(nn.Module):
self.double_blocks = save_double_blocks
self.single_blocks = save_single_blocks
def prepare_block_swap_before_forward(self):
def prepare_block_swap_before_forward(self, override_blocks_to_swap = None):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks, override_blocks_to_swap)
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks, override_blocks_to_swap)
def forward(
self,

View File

@@ -235,7 +235,19 @@ def sample_image_inference(
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
with accelerator.autocast(), torch.no_grad():
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
x = denoise(
flux,
noise,
img_ids,
t5_out,
txt_ids,
l_pooled,
timesteps=timesteps,
guidance=scale,
t5_attn_mask=t5_attn_mask,
controlnet=controlnet,
controlnet_img=controlnet_image,
blocks_to_swap_while_sampling=args.blocks_to_swap_while_sampling)
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
@@ -313,6 +325,7 @@ def denoise(
t5_attn_mask: Optional[torch.Tensor] = None,
controlnet: Optional[flux_models.ControlNetFlux] = None,
controlnet_img: Optional[torch.Tensor] = None,
blocks_to_swap_while_sampling = None
):
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
@@ -320,7 +333,11 @@ def denoise(
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
model.prepare_block_swap_before_forward()
# Due to (e.g.) the batch size while sampling being 1, it may be possible to have more
# blocks on the GPU, speeding up sample image generation
model.prepare_block_swap_before_forward(override_blocks_to_swap = blocks_to_swap_while_sampling)
if controlnet is not None:
block_samples, block_single_samples = controlnet(
img=img,
@@ -351,7 +368,9 @@ def denoise(
img = img + (t_prev - t_curr) * pred
model.prepare_block_swap_before_forward()
# Return to the default number of blocks to swap rather than the number to use while sampling
model.prepare_block_swap_before_forward(override_blocks_to_swap = None)
return img
@@ -619,3 +638,11 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
default=3.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
)
# Flux blocks to swap between VRAM and main memory while generating sample images
parser.add_argument(
"--blocks_to_swap_while_sampling",
type=int,
default=-1,
help="It may be possible to get a larger number of Flux model blocks on the GPU while sampling due to the batch size of 1, and perhaps optimizer state not being required.",
)