format: format

This commit is contained in:
kohya-ss
2026-02-08 12:22:54 +09:00
parent 648c045cb0
commit d992037984
7 changed files with 536 additions and 454 deletions

View File

@@ -32,6 +32,7 @@ from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas
# Anima-specific training arguments
def add_anima_training_arguments(parser: argparse.ArgumentParser):
"""Add Anima-specific training arguments to the parser."""
parser.add_argument(
@@ -169,20 +170,20 @@ def get_noisy_model_input_and_timesteps(
"""
bs = latents.shape[0]
timestep_sample_method = getattr(args, 'timestep_sample_method', 'logit_normal')
sigmoid_scale = getattr(args, 'sigmoid_scale', 1.0)
shift = getattr(args, 'discrete_flow_shift', 1.0)
timestep_sample_method = getattr(args, "timestep_sample_method", "logit_normal")
sigmoid_scale = getattr(args, "sigmoid_scale", 1.0)
shift = getattr(args, "discrete_flow_shift", 1.0)
if timestep_sample_method == 'logit_normal':
if timestep_sample_method == "logit_normal":
dist = torch.distributions.normal.Normal(0, 1)
elif timestep_sample_method == 'uniform':
elif timestep_sample_method == "uniform":
dist = torch.distributions.uniform.Uniform(0, 1)
else:
raise NotImplementedError(f"Unknown timestep_sample_method: {timestep_sample_method}")
t = dist.sample((bs,)).to(device)
if timestep_sample_method == 'logit_normal':
if timestep_sample_method == "logit_normal":
t = t * sigmoid_scale
t = torch.sigmoid(t)
@@ -196,10 +197,10 @@ def get_noisy_model_input_and_timesteps(
# Create noisy input: (1 - t) * latents + t * noise
t_expanded = t.view(-1, *([1] * (latents.ndim - 1)))
ip_noise_gamma = getattr(args, 'ip_noise_gamma', None)
ip_noise_gamma = getattr(args, "ip_noise_gamma", None)
if ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if getattr(args, 'ip_noise_gamma_random_strength', False):
if getattr(args, "ip_noise_gamma_random_strength", False):
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * ip_noise_gamma
noisy_model_input = (1 - t_expanded) * latents + t_expanded * (noise + ip_noise_gamma * xi)
else:
@@ -213,6 +214,7 @@ def get_noisy_model_input_and_timesteps(
# Loss weighting
def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
"""Compute loss weighting for Anima training.
@@ -276,15 +278,15 @@ def get_anima_param_groups(
# Store original name for debugging
p.original_name = name
if 'llm_adapter' in name:
if "llm_adapter" in name:
llm_adapter_params.append(p)
elif '.self_attn' in name:
elif ".self_attn" in name:
self_attn_params.append(p)
elif '.cross_attn' in name:
elif ".cross_attn" in name:
cross_attn_params.append(p)
elif '.mlp' in name:
elif ".mlp" in name:
mlp_params.append(p)
elif '.adaln_modulation' in name:
elif ".adaln_modulation" in name:
mod_params.append(p)
else:
base_params.append(p)
@@ -311,9 +313,9 @@ def get_anima_param_groups(
p.requires_grad_(False)
logger.info(f" Frozen {name} params ({len(params)} parameters)")
elif len(params) > 0:
param_groups.append({'params': params, 'lr': lr})
param_groups.append({"params": params, "lr": lr})
total_trainable = sum(p.numel() for group in param_groups for p in group['params'] if p.requires_grad)
total_trainable = sum(p.numel() for group in param_groups for p in group["params"] if p.requires_grad)
logger.info(f"Total trainable parameters: {total_trainable:,}")
return param_groups
@@ -328,10 +330,9 @@ def save_anima_model_on_train_end(
dit: anima_models.MiniTrainDIT,
):
"""Save Anima model at the end of training."""
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(
None, args, False, False, False, is_stable_diffusion_ckpt=True
)
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
dit_sd = dit.state_dict()
# Save with 'net.' prefix for ComfyUI compatibility
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
@@ -350,10 +351,9 @@ def save_anima_model_on_epoch_end_or_stepwise(
dit: anima_models.MiniTrainDIT,
):
"""Save Anima model at epoch end or specific steps."""
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(
None, args, False, False, False, is_stable_diffusion_ckpt=True
)
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
dit_sd = dit.state_dict()
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
@@ -410,9 +410,7 @@ def do_sample(
generator = torch.manual_seed(seed)
else:
generator = None
noise = torch.randn(
latent.size(), dtype=torch.float32, generator=generator, device="cpu"
).to(dtype).to(device)
noise = torch.randn(latent.size(), dtype=torch.float32, generator=generator, device="cpu").to(dtype).to(device)
# Timestep schedule: linear from 1.0 to 0.0
sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)
@@ -512,10 +510,20 @@ def sample_images(
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
_sample_image_inference(
accelerator, args, dit, text_encoder, vae, vae_scale,
tokenize_strategy, text_encoding_strategy,
save_dir, prompt_dict, epoch, steps,
sample_prompts_te_outputs, prompt_replacement,
accelerator,
args,
dit,
text_encoder,
vae,
vae_scale,
tokenize_strategy,
text_encoding_strategy,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
)
# Restore RNG state
@@ -527,10 +535,20 @@ def sample_images(
def _sample_image_inference(
accelerator, args, dit, text_encoder, vae, vae_scale,
tokenize_strategy, text_encoding_strategy,
save_dir, prompt_dict, epoch, steps,
sample_prompts_te_outputs, prompt_replacement,
accelerator,
args,
dit,
text_encoder,
vae,
vae_scale,
tokenize_strategy,
text_encoding_strategy,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
):
"""Generate a single sample image."""
prompt = prompt_dict.get("prompt", "")
@@ -585,7 +603,7 @@ def _sample_image_inference(
t5_attn_mask = t5_attn_mask.to(accelerator.device)
# Process through LLM adapter if available
if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
if dit.use_llm_adapter and hasattr(dit, "llm_adapter"):
crossattn_emb = dit.llm_adapter(
source_hidden_states=prompt_embeds,
target_input_ids=t5_input_ids,
@@ -613,7 +631,7 @@ def _sample_image_inference(
neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
neg_t5_am = neg_t5_am.to(accelerator.device)
if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
if dit.use_llm_adapter and hasattr(dit, "llm_adapter"):
neg_crossattn_emb = dit.llm_adapter(
source_hidden_states=neg_pe,
target_input_ids=neg_t5_ids,
@@ -627,9 +645,16 @@ def _sample_image_inference(
# Generate sample
clean_memory_on_device(accelerator.device)
latents = do_sample(
height, width, seed, dit, crossattn_emb,
sample_steps, dit.t_embedding_norm.weight.dtype,
accelerator.device, scale, neg_crossattn_emb,
height,
width,
seed,
dit,
crossattn_emb,
sample_steps,
dit.t_embedding_norm.weight.dtype,
accelerator.device,
scale,
neg_crossattn_emb,
)
# Decode latents
@@ -662,4 +687,5 @@ def _sample_image_inference(
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)