fix: update to work with fp8_scaled option

This commit is contained in:
Kohya S
2026-02-09 23:31:33 +09:00
parent 0f413974b7
commit 58db77a488
2 changed files with 3 additions and 2 deletions

View File

@@ -96,7 +96,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
return "anima", [qwen3_text_encoder], vae, None # unet loaded lazily
def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, list[nn.Module]]:
loading_dtype = None if args.fp8_base else weight_dtype
loading_dtype = None if args.fp8_scaled else weight_dtype
loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
attn_mode = "torch"

View File

@@ -123,7 +123,8 @@ def load_anima_dit(
FP8_OPTIMIZATION_TARGET_KEYS = ["blocks", ""]
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_embedder", "norm", "adaln", "final_layer"]
# ".embed." excludes Embedding in LLMAdapter
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_embedder", "norm", "adaln", "final_layer", ".embed."]
def load_anima_model(