mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 00:32:25 +00:00
fix: wait all block trasfer before siwtching offloader mode
This commit is contained in:
@@ -86,6 +86,8 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
# Load VAE
|
||||
logger.info("Loading Anima VAE...")
|
||||
vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
|
||||
vae.to(weight_dtype)
|
||||
vae.eval()
|
||||
|
||||
# Return format: (model_type, text_encoders, vae, unet)
|
||||
return "anima", [qwen3_text_encoder], vae, None # unet loaded lazily
|
||||
|
||||
@@ -495,6 +495,7 @@ def sample_images(
|
||||
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
for prompt_dict in prompts:
|
||||
dit.prepare_block_swap_before_forward()
|
||||
_sample_image_inference(
|
||||
accelerator,
|
||||
args,
|
||||
|
||||
@@ -195,6 +195,9 @@ class ModelOffloader(Offloader):
|
||||
self.remove_handles.append(handle)
|
||||
|
||||
def set_forward_only(self, forward_only: bool):
|
||||
# switching must wait for all pending transfers
|
||||
for block_idx in list(self.futures.keys()):
|
||||
self._wait_blocks_move(block_idx)
|
||||
self.forward_only = forward_only
|
||||
|
||||
def __del__(self):
|
||||
@@ -237,6 +240,10 @@ class ModelOffloader(Offloader):
|
||||
if self.debug:
|
||||
print(f"Prepare block devices before forward")
|
||||
|
||||
# wait for all pending transfers
|
||||
for block_idx in list(self.futures.keys()):
|
||||
self._wait_blocks_move(block_idx)
|
||||
|
||||
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
|
||||
b.to(self.device)
|
||||
weighs_to_device(b, self.device) # make sure weights are on device
|
||||
|
||||
Reference in New Issue
Block a user