diff --git a/anima_train_network.py b/anima_train_network.py
index ad4c771c..dd5f85e6 100644
--- a/anima_train_network.py
+++ b/anima_train_network.py
@@ -86,6 +86,8 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
         # Load VAE
         logger.info("Loading Anima VAE...")
         vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
+        vae.to(weight_dtype)
+        vae.eval()
 
         # Return format: (model_type, text_encoders, vae, unet)
         return "anima", [qwen3_text_encoder], vae, None  # unet loaded lazily
diff --git a/library/anima_train_utils.py b/library/anima_train_utils.py
index 29b75188..3b94f952 100644
--- a/library/anima_train_utils.py
+++ b/library/anima_train_utils.py
@@ -495,6 +495,7 @@ def sample_images(
 
     with torch.no_grad(), accelerator.autocast():
         for prompt_dict in prompts:
+            dit.prepare_block_swap_before_forward()
             _sample_image_inference(
                 accelerator,
                 args,
diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py
index 0681dcdc..883379ce 100644
--- a/library/custom_offloading_utils.py
+++ b/library/custom_offloading_utils.py
@@ -195,6 +195,9 @@ class ModelOffloader(Offloader):
             self.remove_handles.append(handle)
 
     def set_forward_only(self, forward_only: bool):
+        # switching must wait for all pending transfers
+        for block_idx in list(self.futures.keys()):
+            self._wait_blocks_move(block_idx)
         self.forward_only = forward_only
 
     def __del__(self):
@@ -237,6 +240,10 @@ class ModelOffloader(Offloader):
         if self.debug:
             print(f"Prepare block devices before forward")
 
+        # wait for all pending transfers
+        for block_idx in list(self.futures.keys()):
+            self._wait_blocks_move(block_idx)
+
         for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
             b.to(self.device)
             weighs_to_device(b, self.device)  # make sure weights are on device