From b67cc5a45791210a925246afce08e24860fe8678 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Wed, 11 Feb 2026 14:56:23 +0900 Subject: [PATCH] feat: update Anima SAI model spec metadata handling --- anima_train_network.py | 2 +- library/anima_train_utils.py | 12 ++++++++---- library/anima_utils.py | 14 +++++++++++--- library/sai_model_spec.py | 13 ++++++++++++- library/train_util.py | 4 +++- 5 files changed, 35 insertions(+), 10 deletions(-) diff --git a/anima_train_network.py b/anima_train_network.py index 860da9a2..4dea15ec 100644 --- a/anima_train_network.py +++ b/anima_train_network.py @@ -375,7 +375,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): return loss def get_sai_model_spec(self, args): - return train_util.get_sai_model_spec(None, args, False, True, False, is_stable_diffusion_ckpt=True) + return train_util.get_sai_model_spec_dataclass(None, args, False, True, False, anima="preview").to_metadata_dict() def update_metadata(self, metadata, args): metadata["ss_weighting_scheme"] = args.weighting_scheme diff --git a/library/anima_train_utils.py b/library/anima_train_utils.py index 3b94f952..11996099 100644 --- a/library/anima_train_utils.py +++ b/library/anima_train_utils.py @@ -312,10 +312,12 @@ def save_anima_model_on_train_end( """Save Anima model at the end of training.""" def sd_saver(ckpt_file, epoch_no, global_step): - sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True) + sai_metadata = train_util.get_sai_model_spec_dataclass( + None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview" + ).to_metadata_dict() dit_sd = dit.state_dict() # Save with 'net.' 
prefix for ComfyUI compatibility - anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype) + anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype) train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) @@ -333,9 +335,11 @@ def save_anima_model_on_epoch_end_or_stepwise( """Save Anima model at epoch end or specific steps.""" def sd_saver(ckpt_file, epoch_no, global_step): - sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True) + sai_metadata = train_util.get_sai_model_spec_dataclass( + None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview" + ).to_metadata_dict() dit_sd = dit.state_dict() - anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype) + anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype) train_util.save_sd_model_on_epoch_end_or_stepwise_common( args, diff --git a/library/anima_utils.py b/library/anima_utils.py index d5bb24df..e24d1813 100644 --- a/library/anima_utils.py +++ b/library/anima_utils.py @@ -407,21 +407,29 @@ def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None): ) -def save_anima_model(save_path: str, dit_state_dict: Dict[str, torch.Tensor], dtype: Optional[torch.dtype] = None): +def save_anima_model( + save_path: str, dit_state_dict: Dict[str, torch.Tensor], metadata: Dict[str, any], dtype: Optional[torch.dtype] = None +): """Save Anima DiT model with 'net.' prefix for ComfyUI compatibility. Args: save_path: Output path (.safetensors) dit_state_dict: State dict from dit.state_dict() + metadata: Metadata dict to include in the safetensors file dtype: Optional dtype to cast to before saving """ prefixed_sd = {} for k, v in dit_state_dict.items(): if dtype is not None: - v = v.to(dtype) + # v = v.to(dtype) + v = v.detach().clone().to("cpu").to(dtype) # Reduce GPU memory usage during save prefixed_sd["net." 
+ k] = v.contiguous() - save_file(prefixed_sd, save_path, metadata={"format": "pt"}) + if metadata is None: + metadata = {} + metadata["format"] = "pt" # For compatibility with the official .safetensors file + + save_file(prefixed_sd, save_path, metadata=metadata) # safetensors.save_file consumes a lot of memory, but Anima is small enough logger.info(f"Saved Anima model to {save_path}") diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index 32a4fd7b..0ac9b3be 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -81,6 +81,8 @@ ARCH_LUMINA_2 = "lumina-2" ARCH_LUMINA_UNKNOWN = "lumina" ARCH_HUNYUAN_IMAGE_2_1 = "hunyuan-image-2.1" ARCH_HUNYUAN_IMAGE_UNKNOWN = "hunyuan-image" +ARCH_ANIMA_PREVIEW = "anima-preview" +ARCH_ANIMA_UNKNOWN = "anima-unknown" ADAPTER_LORA = "lora" ADAPTER_TEXTUAL_INVERSION = "textual-inversion" @@ -92,6 +94,7 @@ IMPL_FLUX = "https://github.com/black-forest-labs/flux" IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma" IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0" IMPL_HUNYUAN_IMAGE = "https://github.com/Tencent-Hunyuan/HunyuanImage-2.1" +IMPL_ANIMA = "https://huggingface.co/circlestone-labs/Anima" PRED_TYPE_EPSILON = "epsilon" PRED_TYPE_V = "v" @@ -220,6 +223,12 @@ def determine_architecture( arch = ARCH_HUNYUAN_IMAGE_2_1 else: arch = ARCH_HUNYUAN_IMAGE_UNKNOWN + elif "anima" in model_config: + anima_type = model_config["anima"] + if anima_type == "preview": + arch = ARCH_ANIMA_PREVIEW + else: + arch = ARCH_ANIMA_UNKNOWN elif v2: arch = ARCH_SD_V2_768_V if v_parameterization else ARCH_SD_V2_512 else: @@ -252,6 +261,8 @@ def determine_implementation( return IMPL_FLUX elif "lumina" in model_config: return IMPL_LUMINA + elif "anima" in model_config: + return IMPL_ANIMA elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: return IMPL_STABILITY_AI else: @@ -325,7 +336,7 @@ def determine_resolution( reso = (reso[0], reso[0]) else: # Determine default resolution based on 
model type - if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config: + if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config or "anima" in model_config: reso = (1024, 1024) elif v2 and v_parameterization: reso = (768, 768) diff --git a/library/train_util.py b/library/train_util.py index 7d410b79..d8577b9d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3581,6 +3581,7 @@ def get_sai_model_spec_dataclass( flux: str = None, lumina: str = None, hunyuan_image: str = None, + anima: str = None, optional_metadata: dict[str, str] | None = None, ) -> sai_model_spec.ModelSpecMetadata: """ @@ -3612,7 +3613,8 @@ def get_sai_model_spec_dataclass( model_config["lumina"] = lumina if hunyuan_image is not None: model_config["hunyuan_image"] = hunyuan_image - + if anima is not None: + model_config["anima"] = anima # Use the dataclass function directly return sai_model_spec.build_metadata_dataclass( state_dict,