fix: sample generation doesn't work with block swap

This commit is contained in:
Kohya S
2025-09-13 21:06:04 +09:00
parent bae7fa74eb
commit d831c88832
2 changed files with 19 additions and 2 deletions

View File

@@ -1,7 +1,7 @@
import argparse
import copy
import gc
from typing import Any, Optional, Union
from typing import Any, Optional, Union, cast
import argparse
import os
import time
@@ -47,7 +47,7 @@ def sample_images(
args: argparse.Namespace,
epoch,
steps,
dit,
dit: hunyuan_image_models.HYImageDiffusionTransformer,
vae,
text_encoders,
sample_prompts_te_outputs,
@@ -77,6 +77,8 @@ def sample_images(
# unwrap unet and text_encoder(s)
dit = accelerator.unwrap_model(dit)
dit = cast(hunyuan_image_models.HYImageDiffusionTransformer, dit)
dit.switch_block_swap_for_inference()
if text_encoders is not None:
text_encoders = [(accelerator.unwrap_model(te) if te is not None else None) for te in text_encoders]
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
@@ -139,6 +141,7 @@ def sample_images(
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
dit.switch_block_swap_for_training()
clean_memory_on_device(accelerator.device)

View File

@@ -185,6 +185,20 @@ class HYImageDiffusionTransformer(nn.Module):
f"HunyuanImage-2.1: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
)
def switch_block_swap_for_inference(self):
if self.blocks_to_swap:
self.offloader_double.set_forward_only(True)
self.offloader_single.set_forward_only(True)
self.prepare_block_swap_before_forward()
print(f"HunyuanImage-2.1: Block swap set to forward only.")
def switch_block_swap_for_training(self):
if self.blocks_to_swap:
self.offloader_double.set_forward_only(False)
self.offloader_single.set_forward_only(False)
self.prepare_block_swap_before_forward()
print(f"HunyuanImage-2.1: Block swap set to forward and backward.")
def move_to_device_except_swap_blocks(self, device: torch.device):
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
if self.blocks_to_swap: