diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py
index 84c2b743..8936595d 100644
--- a/library/custom_offloading_utils.py
+++ b/library/custom_offloading_utils.py
@@ -194,18 +194,25 @@ class ModelOffloader(Offloader):
 
         return backward_hook
 
-    def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
+    def prepare_block_devices_before_forward(self, blocks: list[nn.Module], override_blocks_to_swap=None):
         if self.blocks_to_swap is None or self.blocks_to_swap == 0:
             return
 
+        # More blocks might be sent to the card when e.g. performing sample image generation
+        if override_blocks_to_swap is None or override_blocks_to_swap == -1:  # 0 is a valid number of blocks to swap
+            blocks_to_swap = self.blocks_to_swap
+        else:
+            # assert cond, msg -- NOT assert(cond, msg): a tuple is always truthy and never fires
+            assert override_blocks_to_swap <= self.blocks_to_swap, "Can't override with a greater number of blocks to swap."
+            blocks_to_swap = override_blocks_to_swap
+
         if self.debug:
-            print("Prepare block devices before forward")
+            print(f"Prepare block devices before forward. Setting blocks to swap to {blocks_to_swap}.")
 
-        for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
+        for b in blocks[0 : self.num_blocks - blocks_to_swap]:
             b.to(self.device)
             weighs_to_device(b, self.device)  # make sure weights are on device
 
-        for b in blocks[self.num_blocks - self.blocks_to_swap :]:
+        for b in blocks[self.num_blocks - blocks_to_swap :]:
             b.to(self.device)  # move block to device first
             weighs_to_device(b, "cpu")  # make sure weights are on cpu
diff --git a/library/flux_models.py b/library/flux_models.py
index 328ad481..ec5f6d9d 100644
--- a/library/flux_models.py
+++ b/library/flux_models.py
@@ -1000,11 +1000,11 @@ class Flux(nn.Module):
         self.double_blocks = save_double_blocks
         self.single_blocks = save_single_blocks
 
-    def prepare_block_swap_before_forward(self):
+    def prepare_block_swap_before_forward(self, override_blocks_to_swap=None):
         if self.blocks_to_swap is None or self.blocks_to_swap == 0:
             return
-        self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
-        self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
+        self.offloader_double.prepare_block_devices_before_forward(self.double_blocks, override_blocks_to_swap)
+        self.offloader_single.prepare_block_devices_before_forward(self.single_blocks, override_blocks_to_swap)
 
     def forward(
         self,
diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py
index 0e73a01d..2a6b1f2d 100644
--- a/library/flux_train_utils.py
+++ b/library/flux_train_utils.py
@@ -235,7 +235,19 @@ def sample_image_inference(
         controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
 
     with accelerator.autocast(), torch.no_grad():
-        x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
+        x = denoise(
+            flux,
+            noise,
+            img_ids,
+            t5_out,
+            txt_ids,
+            l_pooled,
+            timesteps=timesteps,
+            guidance=scale,
+            t5_attn_mask=t5_attn_mask,
+            controlnet=controlnet,
+            controlnet_img=controlnet_image,
+            blocks_to_swap_while_sampling=getattr(args, "blocks_to_swap_while_sampling", -1))
 
     x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
 
@@ -313,6 +325,7 @@ def denoise(
     t5_attn_mask: Optional[torch.Tensor] = None,
     controlnet: Optional[flux_models.ControlNetFlux] = None,
     controlnet_img: Optional[torch.Tensor] = None,
+    blocks_to_swap_while_sampling: Optional[int] = None,
 ):
     # this is ignored for schnell
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
@@ -320,7 +333,11 @@ def denoise(
 
     for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
-        model.prepare_block_swap_before_forward()
+
+        # Due to (e.g.) the batch size while sampling being 1, it may be possible to have more
+        # blocks on the GPU, speeding up sample image generation
+        model.prepare_block_swap_before_forward(override_blocks_to_swap=blocks_to_swap_while_sampling)
+
         if controlnet is not None:
             block_samples, block_single_samples = controlnet(
                 img=img,
@@ -351,7 +368,9 @@ def denoise(
 
         img = img + (t_prev - t_curr) * pred
 
-    model.prepare_block_swap_before_forward()
+    # Return to the default number of blocks to swap rather than the number to use while sampling
+    model.prepare_block_swap_before_forward(override_blocks_to_swap=None)
+
     return img
 
 
@@ -619,3 +638,11 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
         default=3.0,
         help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
     )
+
+    # Flux blocks to swap between VRAM and main memory while generating sample images
+    parser.add_argument(
+        "--blocks_to_swap_while_sampling",
+        type=int,
+        default=-1,
+        help="It may be possible to get a larger number of Flux model blocks on the GPU while sampling due to the batch size of 1, and perhaps optimizer state not being required.",
+    )