diff --git a/anima_train_network.py b/anima_train_network.py
index ed5dc2ba..812fda7d 100644
--- a/anima_train_network.py
+++ b/anima_train_network.py
@@ -96,7 +96,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
         return "anima", [qwen3_text_encoder], vae, None  # unet loaded lazily
 
     def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, list[nn.Module]]:
-        loading_dtype = None if args.fp8_base else weight_dtype
+        loading_dtype = None if args.fp8_scaled else weight_dtype
         loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
 
         attn_mode = "torch"
diff --git a/library/anima_utils.py b/library/anima_utils.py
index d9791fc7..213188e5 100644
--- a/library/anima_utils.py
+++ b/library/anima_utils.py
@@ -123,7 +123,8 @@ def load_anima_dit(
 
 
 FP8_OPTIMIZATION_TARGET_KEYS = ["blocks", ""]
-FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_embedder", "norm", "adaln", "final_layer"]
+# ".embed." excludes Embedding in LLMAdapter
+FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_embedder", "norm", "adaln", "final_layer", ".embed."]
 
 
 def load_anima_model(