mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 17:02:45 +00:00
fix: sample generation doesn't work with block swap
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
import copy
|
||||
import gc
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional, Union, cast
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
@@ -47,7 +47,7 @@ def sample_images(
|
||||
args: argparse.Namespace,
|
||||
epoch,
|
||||
steps,
|
||||
dit,
|
||||
dit: hunyuan_image_models.HYImageDiffusionTransformer,
|
||||
vae,
|
||||
text_encoders,
|
||||
sample_prompts_te_outputs,
|
||||
@@ -77,6 +77,8 @@ def sample_images(
|
||||
|
||||
# unwrap unet and text_encoder(s)
|
||||
dit = accelerator.unwrap_model(dit)
|
||||
dit = cast(hunyuan_image_models.HYImageDiffusionTransformer, dit)
|
||||
dit.switch_block_swap_for_inference()
|
||||
if text_encoders is not None:
|
||||
text_encoders = [(accelerator.unwrap_model(te) if te is not None else None) for te in text_encoders]
|
||||
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
|
||||
@@ -139,6 +141,7 @@ def sample_images(
|
||||
if cuda_rng_state is not None:
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
|
||||
dit.switch_block_swap_for_training()
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
|
||||
|
||||
@@ -185,6 +185,20 @@ class HYImageDiffusionTransformer(nn.Module):
|
||||
f"HunyuanImage-2.1: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
|
||||
)
|
||||
|
||||
def switch_block_swap_for_inference(self):
    """Put the block-swap offloaders into forward-only mode for sampling.

    No-op when block swap is disabled (``self.blocks_to_swap`` is falsy).
    Otherwise both offloaders are switched to forward-only mode
    (presumably skipping backward-pass swap bookkeeping — semantics live in
    the offloader implementation) and the swap state is re-staged so the
    next forward pass starts clean.
    """
    if self.blocks_to_swap:
        self.offloader_double.set_forward_only(True)
        self.offloader_single.set_forward_only(True)
        # Re-stage blocks so the first inference forward starts from a
        # consistent swap state.
        self.prepare_block_swap_before_forward()
        # Plain string literal: the original f-string had no placeholders (F541).
        print("HunyuanImage-2.1: Block swap set to forward only.")
|
||||
|
||||
def switch_block_swap_for_training(self):
    """Restore the block-swap offloaders to forward-and-backward mode.

    Counterpart of ``switch_block_swap_for_inference``; call after sample
    generation to resume training. No-op when block swap is disabled
    (``self.blocks_to_swap`` is falsy). Otherwise forward-only mode is
    cleared on both offloaders and the swap state is re-staged.
    """
    if self.blocks_to_swap:
        self.offloader_double.set_forward_only(False)
        self.offloader_single.set_forward_only(False)
        # Re-stage blocks so the next training forward starts from a
        # consistent swap state.
        self.prepare_block_swap_before_forward()
        # Plain string literal: the original f-string had no placeholders (F541).
        print("HunyuanImage-2.1: Block swap set to forward and backward.")
|
||||
|
||||
def move_to_device_except_swap_blocks(self, device: torch.device):
|
||||
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
|
||||
if self.blocks_to_swap:
|
||||
|
||||
Reference in New Issue
Block a user