fix: wait for all block transfers before switching offloader mode

This commit is contained in:
Kohya S
2026-02-10 21:46:58 +09:00
parent 4992aae311
commit dbb40ae4c0
3 changed files with 10 additions and 0 deletions

View File

@@ -86,6 +86,8 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
# Load VAE
logger.info("Loading Anima VAE...")
vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
vae.to(weight_dtype)
vae.eval()
# Return format: (model_type, text_encoders, vae, unet)
return "anima", [qwen3_text_encoder], vae, None # unet loaded lazily

View File

@@ -495,6 +495,7 @@ def sample_images(
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
dit.prepare_block_swap_before_forward()
_sample_image_inference(
accelerator,
args,

View File

@@ -195,6 +195,9 @@ class ModelOffloader(Offloader):
self.remove_handles.append(handle)
def set_forward_only(self, forward_only: bool):
    """Switch the offloader's forward-only flag after draining transfers.

    Every block index currently registered in ``self.futures`` is passed to
    ``self._wait_blocks_move`` before the flag is updated, so the mode never
    changes while a transfer is still outstanding.

    Args:
        forward_only: New value stored in ``self.forward_only``.
    """
    # Iterate over a snapshot of the keys: presumably ``_wait_blocks_move``
    # removes the finished entry from ``self.futures`` -- TODO confirm.
    pending_indices = tuple(self.futures)
    for pending_idx in pending_indices:
        self._wait_blocks_move(pending_idx)

    self.forward_only = forward_only
def __del__(self):
@@ -237,6 +240,10 @@ class ModelOffloader(Offloader):
if self.debug:
print(f"Prepare block devices before forward")
# wait for all pending transfers
for block_idx in list(self.futures.keys()):
self._wait_blocks_move(block_idx)
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
b.to(self.device)
weighs_to_device(b, self.device) # make sure weights are on device