feat: update Anima SAI model spec metadata handling

This commit is contained in:
Kohya S
2026-02-11 14:56:23 +09:00
parent 90725eba64
commit b67cc5a457
5 changed files with 35 additions and 10 deletions

View File

@@ -375,7 +375,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
return loss
def get_sai_model_spec(self, args):
return train_util.get_sai_model_spec(None, args, False, True, False, is_stable_diffusion_ckpt=True)
return train_util.get_sai_model_spec_dataclass(None, args, False, True, False, anima="preview").to_metadata_dict()
def update_metadata(self, metadata, args):
metadata["ss_weighting_scheme"] = args.weighting_scheme

View File

@@ -312,10 +312,12 @@ def save_anima_model_on_train_end(
"""Save Anima model at the end of training."""
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
sai_metadata = train_util.get_sai_model_spec_dataclass(
None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
).to_metadata_dict()
dit_sd = dit.state_dict()
# Save with 'net.' prefix for ComfyUI compatibility
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype)
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
@@ -333,9 +335,11 @@ def save_anima_model_on_epoch_end_or_stepwise(
"""Save Anima model at epoch end or specific steps."""
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
sai_metadata = train_util.get_sai_model_spec_dataclass(
None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
).to_metadata_dict()
dit_sd = dit.state_dict()
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype)
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,

View File

@@ -407,21 +407,29 @@ def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None):
)
def save_anima_model(save_path: str, dit_state_dict: Dict[str, torch.Tensor], dtype: Optional[torch.dtype] = None):
def save_anima_model(
save_path: str, dit_state_dict: Dict[str, torch.Tensor], metadata: Dict[str, any], dtype: Optional[torch.dtype] = None
):
"""Save Anima DiT model with 'net.' prefix for ComfyUI compatibility.
Args:
save_path: Output path (.safetensors)
dit_state_dict: State dict from dit.state_dict()
metadata: Metadata dict to include in the safetensors file
dtype: Optional dtype to cast to before saving
"""
prefixed_sd = {}
for k, v in dit_state_dict.items():
if dtype is not None:
v = v.to(dtype)
# v = v.to(dtype)
v = v.detach().clone().to("cpu").to(dtype) # Reduce GPU memory usage during save
prefixed_sd["net." + k] = v.contiguous()
save_file(prefixed_sd, save_path, metadata={"format": "pt"})
if metadata is None:
metadata = {}
metadata["format"] = "pt" # For compatibility with the official .safetensors file
save_file(prefixed_sd, save_path, metadata=metadata) # safetensors.save_file consumes a lot of memory, but Anima is small enough
logger.info(f"Saved Anima model to {save_path}")

View File

@@ -81,6 +81,8 @@ ARCH_LUMINA_2 = "lumina-2"
ARCH_LUMINA_UNKNOWN = "lumina"
ARCH_HUNYUAN_IMAGE_2_1 = "hunyuan-image-2.1"
ARCH_HUNYUAN_IMAGE_UNKNOWN = "hunyuan-image"
ARCH_ANIMA_PREVIEW = "anima-preview"
ARCH_ANIMA_UNKNOWN = "anima-unknown"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
@@ -92,6 +94,7 @@ IMPL_FLUX = "https://github.com/black-forest-labs/flux"
IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma"
IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0"
IMPL_HUNYUAN_IMAGE = "https://github.com/Tencent-Hunyuan/HunyuanImage-2.1"
IMPL_ANIMA = "https://huggingface.co/circlestone-labs/Anima"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@@ -220,6 +223,12 @@ def determine_architecture(
arch = ARCH_HUNYUAN_IMAGE_2_1
else:
arch = ARCH_HUNYUAN_IMAGE_UNKNOWN
elif "anima" in model_config:
anima_type = model_config["anima"]
if anima_type == "preview":
arch = ARCH_ANIMA_PREVIEW
else:
arch = ARCH_ANIMA_UNKNOWN
elif v2:
arch = ARCH_SD_V2_768_V if v_parameterization else ARCH_SD_V2_512
else:
@@ -252,6 +261,8 @@ def determine_implementation(
return IMPL_FLUX
elif "lumina" in model_config:
return IMPL_LUMINA
elif "anima" in model_config:
return IMPL_ANIMA
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
return IMPL_STABILITY_AI
else:
@@ -325,7 +336,7 @@ def determine_resolution(
reso = (reso[0], reso[0])
else:
# Determine default resolution based on model type
if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config:
if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config or "anima" in model_config:
reso = (1024, 1024)
elif v2 and v_parameterization:
reso = (768, 768)

View File

@@ -3581,6 +3581,7 @@ def get_sai_model_spec_dataclass(
flux: str = None,
lumina: str = None,
hunyuan_image: str = None,
anima: str = None,
optional_metadata: dict[str, str] | None = None,
) -> sai_model_spec.ModelSpecMetadata:
"""
@@ -3612,7 +3613,8 @@ def get_sai_model_spec_dataclass(
model_config["lumina"] = lumina
if hunyuan_image is not None:
model_config["hunyuan_image"] = hunyuan_image
if anima is not None:
model_config["anima"] = anima
# Use the dataclass function directly
return sai_model_spec.build_metadata_dataclass(
state_dict,