Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-16 00:49:40 +00:00
Added --blocks_to_swap_while_sampling as an optional command line parameter.
@@ -194,18 +194,25 @@ class ModelOffloader(Offloader):
 
         return backward_hook
 
-    def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
+    def prepare_block_devices_before_forward(self, blocks: list[nn.Module], override_blocks_to_swap=None):
         if self.blocks_to_swap is None or self.blocks_to_swap == 0:
             return
 
-        if self.debug:
-            print("Prepare block devices before forward")
+        # More blocks might be sent to the card when, e.g., performing sample image generation
+        if override_blocks_to_swap is None or override_blocks_to_swap == -1:  # 0 is a valid number of blocks to swap
+            blocks_to_swap = self.blocks_to_swap
+        else:
+            assert override_blocks_to_swap <= self.blocks_to_swap, "Can't override with a greater number of blocks to swap."
+            blocks_to_swap = override_blocks_to_swap
 
-        for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
+        if self.debug:
+            print(f"Prepare block devices before forward. Setting blocks to swap to {blocks_to_swap}.")
+
+        for b in blocks[0 : self.num_blocks - blocks_to_swap]:
             b.to(self.device)
             weighs_to_device(b, self.device)  # make sure weights are on device
 
-        for b in blocks[self.num_blocks - self.blocks_to_swap :]:
+        for b in blocks[self.num_blocks - blocks_to_swap :]:
             b.to(self.device)  # move block to device first
             weighs_to_device(b, "cpu")  # make sure weights are on cpu
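For intuition about what this hunk does: the offloader keeps the first num_blocks - blocks_to_swap blocks resident on the GPU and stages the remaining blocks_to_swap blocks on the CPU, so a smaller override leaves more of the model on the card. Below is a minimal, self-contained sketch of that partition; toy nn.Linear blocks stand in for Flux blocks, and the real ModelOffloader additionally manages weight movement via weighs_to_device and its swap hooks.

    import torch
    import torch.nn as nn

    num_blocks = 8
    blocks_to_swap = 3  # training value; an override of e.g. 1 would keep two more blocks on the GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    blocks = [nn.Linear(16, 16) for _ in range(num_blocks)]

    # The first num_blocks - blocks_to_swap blocks stay resident on the device.
    for b in blocks[: num_blocks - blocks_to_swap]:
        b.to(device)

    # The trailing blocks_to_swap blocks wait on the CPU until they are swapped in.
    for b in blocks[num_blocks - blocks_to_swap :]:
        b.to("cpu")

    print([str(next(b.parameters()).device) for b in blocks])
    # e.g. ['cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cuda:0', 'cpu', 'cpu', 'cpu']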
@@ -1000,11 +1000,11 @@ class Flux(nn.Module):
         self.double_blocks = save_double_blocks
         self.single_blocks = save_single_blocks
 
-    def prepare_block_swap_before_forward(self):
+    def prepare_block_swap_before_forward(self, override_blocks_to_swap=None):
         if self.blocks_to_swap is None or self.blocks_to_swap == 0:
             return
-        self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
-        self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
+        self.offloader_double.prepare_block_devices_before_forward(self.double_blocks, override_blocks_to_swap)
+        self.offloader_single.prepare_block_devices_before_forward(self.single_blocks, override_blocks_to_swap)
 
     def forward(
         self,
@@ -235,7 +235,19 @@ def sample_image_inference(
         controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
 
     with accelerator.autocast(), torch.no_grad():
-        x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
+        x = denoise(
+            flux,
+            noise,
+            img_ids,
+            t5_out,
+            txt_ids,
+            l_pooled,
+            timesteps=timesteps,
+            guidance=scale,
+            t5_attn_mask=t5_attn_mask,
+            controlnet=controlnet,
+            controlnet_img=controlnet_image,
+            blocks_to_swap_while_sampling=args.blocks_to_swap_while_sampling)
 
     x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
@@ -313,6 +325,7 @@ def denoise(
     t5_attn_mask: Optional[torch.Tensor] = None,
     controlnet: Optional[flux_models.ControlNetFlux] = None,
     controlnet_img: Optional[torch.Tensor] = None,
+    blocks_to_swap_while_sampling=None,
 ):
     # this is ignored for schnell
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
@@ -320,7 +333,11 @@ def denoise(
 
     for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
-        model.prepare_block_swap_before_forward()
+
+        # Due to (e.g.) the batch size while sampling being 1, it may be possible to have more
+        # blocks on the GPU, speeding up sample image generation
+        model.prepare_block_swap_before_forward(override_blocks_to_swap=blocks_to_swap_while_sampling)
 
         if controlnet is not None:
             block_samples, block_single_samples = controlnet(
                 img=img,
@@ -351,7 +368,9 @@ def denoise(
 
         img = img + (t_prev - t_curr) * pred
 
-    model.prepare_block_swap_before_forward()
+    # Return to the default number of blocks to swap rather than the number to use while sampling
+    model.prepare_block_swap_before_forward(override_blocks_to_swap=None)
 
     return img
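The two call sites above form an override-then-restore pattern: denoise requests the sampling swap count at each step, then hands the training default back before returning. A hedged sketch of the same pattern as a helper (the function name is hypothetical; model stands in for the Flux instance, and the try/finally additionally restores placement if sampling raises):

    def run_with_sampling_override(model, blocks_to_swap_while_sampling, run_sampling):
        # Swap fewer blocks (keep more on the GPU) for the duration of sampling...
        model.prepare_block_swap_before_forward(override_blocks_to_swap=blocks_to_swap_while_sampling)
        try:
            return run_sampling()
        finally:
            # ...then return to the training default (override None) so the next
            # training step sees the expected device placement.
            model.prepare_block_swap_before_forward(override_blocks_to_swap=None)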
@@ -619,3 +638,11 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
         default=3.0,
         help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
     )
+
+    # Flux blocks to swap between VRAM and main memory while generating sample images
+    parser.add_argument(
+        "--blocks_to_swap_while_sampling",
+        type=int,
+        default=-1,
+        help="Number of Flux model blocks to swap to CPU while generating sample images; -1 (the default) reuses the training swap count. Because sampling runs at batch size 1 and needs no optimizer state, fewer blocks may need to be swapped, keeping more of the model on the GPU.",
+    )
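Since -1 is the "no override" sentinel (0 legitimately means "swap nothing"), the effective swap count used while sampling can be derived as in this minimal sketch, assuming the existing --blocks_to_swap training option alongside the new flag:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--blocks_to_swap", type=int, default=None)
    parser.add_argument("--blocks_to_swap_while_sampling", type=int, default=-1)
    args = parser.parse_args(["--blocks_to_swap", "20", "--blocks_to_swap_while_sampling", "8"])

    override = args.blocks_to_swap_while_sampling
    if override is None or override == -1:  # -1: no override requested
        effective = args.blocks_to_swap     # fall back to the training value
    else:
        assert override <= args.blocks_to_swap, "Can't override with a greater number of blocks to swap."
        effective = override                # 0 is valid: keep every block on the GPU

    print(effective)  # 8 -> twelve more blocks stay on the GPU than during training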