mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
feat: update Anima SAI model spec metadata handling
This commit is contained in:
@@ -375,7 +375,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
return loss
|
||||
|
||||
def get_sai_model_spec(self, args):
|
||||
return train_util.get_sai_model_spec(None, args, False, True, False, is_stable_diffusion_ckpt=True)
|
||||
return train_util.get_sai_model_spec_dataclass(None, args, False, True, False, anima="preview").to_metadata_dict()
|
||||
|
||||
def update_metadata(self, metadata, args):
|
||||
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
||||
|
||||
@@ -312,10 +312,12 @@ def save_anima_model_on_train_end(
|
||||
"""Save Anima model at the end of training."""
|
||||
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
|
||||
sai_metadata = train_util.get_sai_model_spec_dataclass(
|
||||
None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
|
||||
).to_metadata_dict()
|
||||
dit_sd = dit.state_dict()
|
||||
# Save with 'net.' prefix for ComfyUI compatibility
|
||||
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
|
||||
anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype)
|
||||
|
||||
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
|
||||
|
||||
@@ -333,9 +335,11 @@ def save_anima_model_on_epoch_end_or_stepwise(
|
||||
"""Save Anima model at epoch end or specific steps."""
|
||||
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
|
||||
sai_metadata = train_util.get_sai_model_spec_dataclass(
|
||||
None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
|
||||
).to_metadata_dict()
|
||||
dit_sd = dit.state_dict()
|
||||
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
|
||||
anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype)
|
||||
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
|
||||
args,
|
||||
|
||||
@@ -407,21 +407,29 @@ def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None):
|
||||
)
|
||||
|
||||
|
||||
def save_anima_model(save_path: str, dit_state_dict: Dict[str, torch.Tensor], dtype: Optional[torch.dtype] = None):
|
||||
def save_anima_model(
|
||||
save_path: str, dit_state_dict: Dict[str, torch.Tensor], metadata: Dict[str, any], dtype: Optional[torch.dtype] = None
|
||||
):
|
||||
"""Save Anima DiT model with 'net.' prefix for ComfyUI compatibility.
|
||||
|
||||
Args:
|
||||
save_path: Output path (.safetensors)
|
||||
dit_state_dict: State dict from dit.state_dict()
|
||||
metadata: Metadata dict to include in the safetensors file
|
||||
dtype: Optional dtype to cast to before saving
|
||||
"""
|
||||
prefixed_sd = {}
|
||||
for k, v in dit_state_dict.items():
|
||||
if dtype is not None:
|
||||
v = v.to(dtype)
|
||||
# v = v.to(dtype)
|
||||
v = v.detach().clone().to("cpu").to(dtype) # Reduce GPU memory usage during save
|
||||
prefixed_sd["net." + k] = v.contiguous()
|
||||
|
||||
save_file(prefixed_sd, save_path, metadata={"format": "pt"})
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
metadata["format"] = "pt" # For compatibility with the official .safetensors file
|
||||
|
||||
save_file(prefixed_sd, save_path, metadata=metadata) # safetensors.save_file cosumes a lot of memory, but Anima is small enough
|
||||
logger.info(f"Saved Anima model to {save_path}")
|
||||
|
||||
|
||||
|
||||
@@ -81,6 +81,8 @@ ARCH_LUMINA_2 = "lumina-2"
|
||||
ARCH_LUMINA_UNKNOWN = "lumina"
|
||||
ARCH_HUNYUAN_IMAGE_2_1 = "hunyuan-image-2.1"
|
||||
ARCH_HUNYUAN_IMAGE_UNKNOWN = "hunyuan-image"
|
||||
ARCH_ANIMA_PREVIEW = "anima-preview"
|
||||
ARCH_ANIMA_UNKNOWN = "anima-unknown"
|
||||
|
||||
ADAPTER_LORA = "lora"
|
||||
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
|
||||
@@ -92,6 +94,7 @@ IMPL_FLUX = "https://github.com/black-forest-labs/flux"
|
||||
IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma"
|
||||
IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0"
|
||||
IMPL_HUNYUAN_IMAGE = "https://github.com/Tencent-Hunyuan/HunyuanImage-2.1"
|
||||
IMPL_ANIMA = "https://huggingface.co/circlestone-labs/Anima"
|
||||
|
||||
PRED_TYPE_EPSILON = "epsilon"
|
||||
PRED_TYPE_V = "v"
|
||||
@@ -220,6 +223,12 @@ def determine_architecture(
|
||||
arch = ARCH_HUNYUAN_IMAGE_2_1
|
||||
else:
|
||||
arch = ARCH_HUNYUAN_IMAGE_UNKNOWN
|
||||
elif "anima" in model_config:
|
||||
anima_type = model_config["anima"]
|
||||
if anima_type == "preview":
|
||||
arch = ARCH_ANIMA_PREVIEW
|
||||
else:
|
||||
arch = ARCH_ANIMA_UNKNOWN
|
||||
elif v2:
|
||||
arch = ARCH_SD_V2_768_V if v_parameterization else ARCH_SD_V2_512
|
||||
else:
|
||||
@@ -252,6 +261,8 @@ def determine_implementation(
|
||||
return IMPL_FLUX
|
||||
elif "lumina" in model_config:
|
||||
return IMPL_LUMINA
|
||||
elif "anima" in model_config:
|
||||
return IMPL_ANIMA
|
||||
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
||||
return IMPL_STABILITY_AI
|
||||
else:
|
||||
@@ -325,7 +336,7 @@ def determine_resolution(
|
||||
reso = (reso[0], reso[0])
|
||||
else:
|
||||
# Determine default resolution based on model type
|
||||
if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config:
|
||||
if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config or "anima" in model_config:
|
||||
reso = (1024, 1024)
|
||||
elif v2 and v_parameterization:
|
||||
reso = (768, 768)
|
||||
|
||||
@@ -3581,6 +3581,7 @@ def get_sai_model_spec_dataclass(
|
||||
flux: str = None,
|
||||
lumina: str = None,
|
||||
hunyuan_image: str = None,
|
||||
anima: str = None,
|
||||
optional_metadata: dict[str, str] | None = None,
|
||||
) -> sai_model_spec.ModelSpecMetadata:
|
||||
"""
|
||||
@@ -3612,7 +3613,8 @@ def get_sai_model_spec_dataclass(
|
||||
model_config["lumina"] = lumina
|
||||
if hunyuan_image is not None:
|
||||
model_config["hunyuan_image"] = hunyuan_image
|
||||
|
||||
if anima is not None:
|
||||
model_config["anima"] = anima
|
||||
# Use the dataclass function directly
|
||||
return sai_model_spec.build_metadata_dataclass(
|
||||
state_dict,
|
||||
|
||||
Reference in New Issue
Block a user