From 4e2a80a6caa546f44a3667a7d9dec6a2c6378591 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sat, 13 Sep 2025 21:07:11 +0900 Subject: [PATCH] refactor: update imports to use safetensors_utils for memory-efficient operations --- hunyuan_image_minimal_inference.py | 3 ++- library/custom_offloading_utils.py | 15 ++++++++------- library/fp8_optimization_utils.py | 3 ++- library/hunyuan_image_text_encoder.py | 4 ++-- library/hunyuan_image_vae.py | 3 ++- library/lora_utils.py | 3 ++- networks/flux_extract_lora.py | 5 ++--- 7 files changed, 20 insertions(+), 16 deletions(-) diff --git a/hunyuan_image_minimal_inference.py b/hunyuan_image_minimal_inference.py index 3de0b1cd..7db490cd 100644 --- a/hunyuan_image_minimal_inference.py +++ b/hunyuan_image_minimal_inference.py @@ -22,6 +22,7 @@ from library import hunyuan_image_models, hunyuan_image_text_encoder, hunyuan_im from library import hunyuan_image_vae from library.hunyuan_image_vae import HunyuanVAE2D from library.device_utils import clean_memory_on_device, synchronize_device +from library.safetensors_utils import mem_eff_save_file from networks import lora_hunyuan_image @@ -29,7 +30,7 @@ lycoris_available = find_spec("lycoris") is not None if lycoris_available: from lycoris.kohya import create_network_from_weights -from library.utils import mem_eff_save_file, setup_logging +from library.utils import setup_logging setup_logging() import logging diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py index 1b7bbc14..fe7e59d2 100644 --- a/library/custom_offloading_utils.py +++ b/library/custom_offloading_utils.py @@ -173,14 +173,12 @@ class ModelOffloader(Offloader): """ def __init__( - self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, supports_backward: bool = True, debug: bool = False, - ): super().__init__(len(blocks), blocks_to_swap, device, debug) @@ -220,7 +218,7 @@ class ModelOffloader(Offloader): block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated block_idx_to_wait = block_index - 1 - def backward_hook(module, grad_input, grad_output): + def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t): if self.debug: print(f"Backward hook for block {block_index}") @@ -232,7 +230,7 @@ class ModelOffloader(Offloader): return backward_hook - def prepare_block_devices_before_forward(self, blocks: list[nn.Module]): + def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn.ModuleList]): if self.blocks_to_swap is None or self.blocks_to_swap == 0: return @@ -245,7 +243,7 @@ class ModelOffloader(Offloader): for b in blocks[self.num_blocks - self.blocks_to_swap :]: b.to(self.device) # move block to device first. this makes sure that buffers (non weights) are on the device - weighs_to_device(b, "cpu") # make sure weights are on cpu + weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu _synchronize_device(self.device) _clean_memory_on_device(self.device) @@ -255,7 +253,7 @@ class ModelOffloader(Offloader): return self._wait_blocks_move(block_idx) - def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int): + def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int): # check if blocks_to_swap is enabled if self.blocks_to_swap is None or self.blocks_to_swap == 0: return @@ -266,7 +264,10 @@ class ModelOffloader(Offloader): block_idx_to_cpu = block_idx block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx - block_idx_to_cuda = block_idx_to_cuda % self.num_blocks # this works for forward-only offloading + + # this works for forward-only offloading. move upstream blocks to cuda + block_idx_to_cuda = block_idx_to_cuda % self.num_blocks + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) diff --git a/library/fp8_optimization_utils.py b/library/fp8_optimization_utils.py index a91eb4e4..ed7d3f76 100644 --- a/library/fp8_optimization_utils.py +++ b/library/fp8_optimization_utils.py @@ -9,7 +9,8 @@ import logging from tqdm import tqdm from library.device_utils import clean_memory_on_device -from library.utils import MemoryEfficientSafeOpen, setup_logging +from library.safetensors_utils import MemoryEfficientSafeOpen +from library.utils import setup_logging setup_logging() import logging diff --git a/library/hunyuan_image_text_encoder.py b/library/hunyuan_image_text_encoder.py index 960f14b3..509f9bd2 100644 --- a/library/hunyuan_image_text_encoder.py +++ b/library/hunyuan_image_text_encoder.py @@ -14,8 +14,8 @@ from transformers import ( from transformers.models.t5.modeling_t5 import T5Stack from accelerate import init_empty_weights -from library import model_util -from library.utils import load_safetensors, setup_logging +from library.safetensors_utils import load_safetensors +from library.utils import setup_logging setup_logging() import logging diff --git a/library/hunyuan_image_vae.py b/library/hunyuan_image_vae.py index 570d4caa..6f6eea22 100644 --- a/library/hunyuan_image_vae.py +++ b/library/hunyuan_image_vae.py @@ -7,7 +7,8 @@ from torch import Tensor, nn from torch.nn import Conv2d from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution -from library.utils import load_safetensors, setup_logging +from library.safetensors_utils import load_safetensors +from library.utils import setup_logging setup_logging() import logging diff --git a/library/lora_utils.py b/library/lora_utils.py index 468fb01a..b93eb9af 100644 --- a/library/lora_utils.py +++ b/library/lora_utils.py @@ -9,7 +9,8 @@ from tqdm import tqdm from library.device_utils import synchronize_device from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization -from library.utils import MemoryEfficientSafeOpen, setup_logging +from library.safetensors_utils import MemoryEfficientSafeOpen +from library.utils import setup_logging setup_logging() import logging diff --git a/networks/flux_extract_lora.py b/networks/flux_extract_lora.py index 63ab2960..65728702 100644 --- a/networks/flux_extract_lora.py +++ b/networks/flux_extract_lora.py @@ -10,9 +10,8 @@ import torch from safetensors.torch import load_file, save_file from safetensors import safe_open from tqdm import tqdm -from library import flux_utils, sai_model_spec, model_util, sdxl_model_util -import lora -from library.utils import MemoryEfficientSafeOpen +from library import flux_utils, sai_model_spec +from library.safetensors_utils import MemoryEfficientSafeOpen from library.utils import setup_logging from networks import lora_flux