mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
Support Anima model (#2260)
* Support Anima model * Update document and fix bug * Fix latent normlization * Fix typo * Fix cache embedding * fix typo in tests/test_anima_cache.py * Remove redundant argument apply_t5_attn_mask * Improving caching with argument caption_dropout_rate * Fix W&B logging bugs * Fix discrete_flow_shift default value
This commit is contained in:
665
library/anima_train_utils.py
Normal file
665
library/anima_train_utils.py
Normal file
@@ -0,0 +1,665 @@
|
||||
# Anima Training Utilities
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from safetensors.torch import save_file
|
||||
from accelerate import Accelerator, PartialState
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import anima_models, anima_utils, strategy_base, train_util
|
||||
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas
|
||||
|
||||
|
||||
# Anima-specific training arguments
|
||||
|
||||
def add_anima_training_arguments(parser: argparse.ArgumentParser):
|
||||
"""Add Anima-specific training arguments to the parser."""
|
||||
parser.add_argument(
|
||||
"--dit_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Anima DiT model safetensors file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to WanVAE safetensors/pth file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--qwen3_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Qwen3-0.6B model (safetensors file or directory)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm_adapter_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to separate LLM adapter weights. If None, adapter is loaded from DiT file if present",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm_adapter_lr",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Learning rate for LLM adapter. None=same as base LR, 0=freeze adapter",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--self_attn_lr",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Learning rate for self-attention layers. None=same as base LR, 0=freeze",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cross_attn_lr",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Learning rate for cross-attention layers. None=same as base LR, 0=freeze",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlp_lr",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Learning rate for MLP layers. None=same as base LR, 0=freeze",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mod_lr",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--t5_tokenizer_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to T5 tokenizer directory. If None, uses default configs/t5_old/",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--qwen3_max_token_length",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Maximum token length for Qwen3 tokenizer (default: 512)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--t5_max_token_length",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Maximum token length for T5 tokenizer (default: 512)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discrete_flow_shift",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Timestep distribution shift for rectified flow training (default: 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timestep_sample_method",
|
||||
type=str,
|
||||
default="logit_normal",
|
||||
choices=["logit_normal", "uniform"],
|
||||
help="Timestep sampling method (default: logit_normal)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sigmoid_scale",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Scale factor for logit_normal timestep sampling (default: 1.0)",
|
||||
)
|
||||
# Note: --caption_dropout_rate is defined by base add_dataset_arguments().
|
||||
# Anima uses embedding-level dropout (via AnimaTextEncodingStrategy.dropout_rate)
|
||||
# instead of dataset-level caption dropout, so the subset caption_dropout_rate
|
||||
# is zeroed out in the training scripts to allow caching.
|
||||
parser.add_argument(
|
||||
"--transformer_dtype",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["float16", "bfloat16", "float32", None],
|
||||
help="Separate dtype for transformer blocks. If None, uses same as mixed_precision",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flash_attn",
|
||||
action="store_true",
|
||||
help="Use Flash Attention for DiT self/cross-attention (requires flash-attn package). "
|
||||
"Falls back to PyTorch SDPA if flash-attn is not installed.",
|
||||
)
|
||||
|
||||
|
||||
# Noise & Timestep sampling (Rectified Flow)
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args,
|
||||
latents: torch.Tensor,
|
||||
noise: torch.Tensor,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Generate noisy model input and timesteps for rectified flow training.
|
||||
|
||||
Rectified flow: noisy_input = (1 - t) * latents + t * noise
|
||||
Target: noise - latents
|
||||
|
||||
Args:
|
||||
args: Training arguments with timestep_sample_method, sigmoid_scale, discrete_flow_shift
|
||||
latents: Clean latent tensors
|
||||
noise: Random noise tensors
|
||||
device: Target device
|
||||
dtype: Target dtype
|
||||
|
||||
Returns:
|
||||
(noisy_model_input, timesteps, sigmas)
|
||||
"""
|
||||
bs = latents.shape[0]
|
||||
|
||||
timestep_sample_method = getattr(args, 'timestep_sample_method', 'logit_normal')
|
||||
sigmoid_scale = getattr(args, 'sigmoid_scale', 1.0)
|
||||
shift = getattr(args, 'discrete_flow_shift', 1.0)
|
||||
|
||||
if timestep_sample_method == 'logit_normal':
|
||||
dist = torch.distributions.normal.Normal(0, 1)
|
||||
elif timestep_sample_method == 'uniform':
|
||||
dist = torch.distributions.uniform.Uniform(0, 1)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown timestep_sample_method: {timestep_sample_method}")
|
||||
|
||||
t = dist.sample((bs,)).to(device)
|
||||
|
||||
if timestep_sample_method == 'logit_normal':
|
||||
t = t * sigmoid_scale
|
||||
t = torch.sigmoid(t)
|
||||
|
||||
# Apply shift
|
||||
if shift is not None and shift != 1.0:
|
||||
t = (t * shift) / (1 + (shift - 1) * t)
|
||||
|
||||
# Clamp to avoid exact 0 or 1
|
||||
t = t.clamp(1e-5, 1.0 - 1e-5)
|
||||
|
||||
# Create noisy input: (1 - t) * latents + t * noise
|
||||
t_expanded = t.view(-1, *([1] * (latents.ndim - 1)))
|
||||
|
||||
ip_noise_gamma = getattr(args, 'ip_noise_gamma', None)
|
||||
if ip_noise_gamma:
|
||||
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
|
||||
if getattr(args, 'ip_noise_gamma_random_strength', False):
|
||||
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * ip_noise_gamma
|
||||
noisy_model_input = (1 - t_expanded) * latents + t_expanded * (noise + ip_noise_gamma * xi)
|
||||
else:
|
||||
noisy_model_input = (1 - t_expanded) * latents + t_expanded * noise
|
||||
|
||||
# Sigmas for potential loss weighting
|
||||
sigmas = t.view(-1, 1)
|
||||
|
||||
return noisy_model_input.to(dtype), t.to(dtype), sigmas.to(dtype)
|
||||
|
||||
|
||||
# Loss weighting
|
||||
|
||||
def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute loss weighting for Anima training.
|
||||
|
||||
Same schemes as SD3 but can add Anima-specific ones.
|
||||
"""
|
||||
if weighting_scheme == "sigma_sqrt":
|
||||
weighting = (sigmas**-2.0).float()
|
||||
elif weighting_scheme == "cosmap":
|
||||
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
||||
weighting = 2 / (math.pi * bot)
|
||||
elif weighting_scheme == "none" or weighting_scheme is None:
|
||||
weighting = torch.ones_like(sigmas)
|
||||
else:
|
||||
weighting = torch.ones_like(sigmas)
|
||||
return weighting
|
||||
|
||||
|
||||
# Parameter groups (6 groups with separate LRs)
|
||||
def get_anima_param_groups(
|
||||
dit,
|
||||
base_lr: float,
|
||||
self_attn_lr: Optional[float] = None,
|
||||
cross_attn_lr: Optional[float] = None,
|
||||
mlp_lr: Optional[float] = None,
|
||||
mod_lr: Optional[float] = None,
|
||||
llm_adapter_lr: Optional[float] = None,
|
||||
):
|
||||
"""Create parameter groups for Anima training with separate learning rates.
|
||||
|
||||
Args:
|
||||
dit: MiniTrainDIT model
|
||||
base_lr: Base learning rate
|
||||
self_attn_lr: LR for self-attention layers (None = base_lr, 0 = freeze)
|
||||
cross_attn_lr: LR for cross-attention layers
|
||||
mlp_lr: LR for MLP layers
|
||||
mod_lr: LR for AdaLN modulation layers
|
||||
llm_adapter_lr: LR for LLM adapter
|
||||
|
||||
Returns:
|
||||
List of parameter group dicts for optimizer
|
||||
"""
|
||||
if self_attn_lr is None:
|
||||
self_attn_lr = base_lr
|
||||
if cross_attn_lr is None:
|
||||
cross_attn_lr = base_lr
|
||||
if mlp_lr is None:
|
||||
mlp_lr = base_lr
|
||||
if mod_lr is None:
|
||||
mod_lr = base_lr
|
||||
if llm_adapter_lr is None:
|
||||
llm_adapter_lr = base_lr
|
||||
|
||||
base_params = []
|
||||
self_attn_params = []
|
||||
cross_attn_params = []
|
||||
mlp_params = []
|
||||
mod_params = []
|
||||
llm_adapter_params = []
|
||||
|
||||
for name, p in dit.named_parameters():
|
||||
# Store original name for debugging
|
||||
p.original_name = name
|
||||
|
||||
if 'llm_adapter' in name:
|
||||
llm_adapter_params.append(p)
|
||||
elif '.self_attn' in name:
|
||||
self_attn_params.append(p)
|
||||
elif '.cross_attn' in name:
|
||||
cross_attn_params.append(p)
|
||||
elif '.mlp' in name:
|
||||
mlp_params.append(p)
|
||||
elif '.adaln_modulation' in name:
|
||||
mod_params.append(p)
|
||||
else:
|
||||
base_params.append(p)
|
||||
|
||||
logger.info(f"Parameter groups:")
|
||||
logger.info(f" base_params: {len(base_params)} (lr={base_lr})")
|
||||
logger.info(f" self_attn_params: {len(self_attn_params)} (lr={self_attn_lr})")
|
||||
logger.info(f" cross_attn_params: {len(cross_attn_params)} (lr={cross_attn_lr})")
|
||||
logger.info(f" mlp_params: {len(mlp_params)} (lr={mlp_lr})")
|
||||
logger.info(f" mod_params: {len(mod_params)} (lr={mod_lr})")
|
||||
logger.info(f" llm_adapter_params: {len(llm_adapter_params)} (lr={llm_adapter_lr})")
|
||||
|
||||
param_groups = []
|
||||
for lr, params, name in [
|
||||
(base_lr, base_params, "base"),
|
||||
(self_attn_lr, self_attn_params, "self_attn"),
|
||||
(cross_attn_lr, cross_attn_params, "cross_attn"),
|
||||
(mlp_lr, mlp_params, "mlp"),
|
||||
(mod_lr, mod_params, "mod"),
|
||||
(llm_adapter_lr, llm_adapter_params, "llm_adapter"),
|
||||
]:
|
||||
if lr == 0:
|
||||
for p in params:
|
||||
p.requires_grad_(False)
|
||||
logger.info(f" Frozen {name} params ({len(params)} parameters)")
|
||||
elif len(params) > 0:
|
||||
param_groups.append({'params': params, 'lr': lr})
|
||||
|
||||
total_trainable = sum(p.numel() for group in param_groups for p in group['params'] if p.requires_grad)
|
||||
logger.info(f"Total trainable parameters: {total_trainable:,}")
|
||||
|
||||
return param_groups
|
||||
|
||||
|
||||
# Save functions
|
||||
def save_anima_model_on_train_end(
|
||||
args: argparse.Namespace,
|
||||
save_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
global_step: int,
|
||||
dit: anima_models.MiniTrainDIT,
|
||||
):
|
||||
"""Save Anima model at the end of training."""
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = train_util.get_sai_model_spec(
|
||||
None, args, False, False, False, is_stable_diffusion_ckpt=True
|
||||
)
|
||||
dit_sd = dit.state_dict()
|
||||
# Save with 'net.' prefix for ComfyUI compatibility
|
||||
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
|
||||
|
||||
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
|
||||
|
||||
|
||||
def save_anima_model_on_epoch_end_or_stepwise(
|
||||
args: argparse.Namespace,
|
||||
on_epoch_end: bool,
|
||||
accelerator: Accelerator,
|
||||
save_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
num_train_epochs: int,
|
||||
global_step: int,
|
||||
dit: anima_models.MiniTrainDIT,
|
||||
):
|
||||
"""Save Anima model at epoch end or specific steps."""
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = train_util.get_sai_model_spec(
|
||||
None, args, False, False, False, is_stable_diffusion_ckpt=True
|
||||
)
|
||||
dit_sd = dit.state_dict()
|
||||
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
|
||||
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
|
||||
args,
|
||||
on_epoch_end,
|
||||
accelerator,
|
||||
True,
|
||||
True,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
sd_saver,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
# Sampling (Euler discrete for rectified flow)
|
||||
def do_sample(
|
||||
height: int,
|
||||
width: int,
|
||||
seed: Optional[int],
|
||||
dit: anima_models.MiniTrainDIT,
|
||||
crossattn_emb: torch.Tensor,
|
||||
steps: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
guidance_scale: float = 1.0,
|
||||
neg_crossattn_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Generate a sample using Euler discrete sampling for rectified flow.
|
||||
|
||||
Args:
|
||||
height, width: Output image dimensions
|
||||
seed: Random seed (None for random)
|
||||
dit: MiniTrainDIT model
|
||||
crossattn_emb: Cross-attention embeddings (B, N, D)
|
||||
steps: Number of sampling steps
|
||||
dtype: Compute dtype
|
||||
device: Compute device
|
||||
guidance_scale: CFG scale (1.0 = no guidance)
|
||||
neg_crossattn_emb: Negative cross-attention embeddings for CFG
|
||||
|
||||
Returns:
|
||||
Denoised latents
|
||||
"""
|
||||
# Latent shape: (1, 16, 1, H/8, W/8) for single image
|
||||
latent_h = height // 8
|
||||
latent_w = width // 8
|
||||
latent = torch.zeros(1, 16, 1, latent_h, latent_w, device=device, dtype=dtype)
|
||||
|
||||
# Generate noise
|
||||
if seed is not None:
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = None
|
||||
noise = torch.randn(
|
||||
latent.size(), dtype=torch.float32, generator=generator, device="cpu"
|
||||
).to(dtype).to(device)
|
||||
|
||||
# Timestep schedule: linear from 1.0 to 0.0
|
||||
sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)
|
||||
|
||||
# Start from pure noise
|
||||
x = noise.clone()
|
||||
|
||||
# Padding mask (zeros = no padding) — resized in prepare_embedded_sequence to match latent dims
|
||||
padding_mask = torch.zeros(1, 1, latent_h, latent_w, dtype=dtype, device=device)
|
||||
|
||||
use_cfg = guidance_scale > 1.0 and neg_crossattn_emb is not None
|
||||
|
||||
for i in tqdm(range(steps), desc="Sampling"):
|
||||
sigma = sigmas[i]
|
||||
t = sigma.unsqueeze(0) # (1,)
|
||||
|
||||
dit.prepare_block_swap_before_forward()
|
||||
|
||||
if use_cfg:
|
||||
# CFG: concat positive and negative
|
||||
x_input = torch.cat([x, x], dim=0)
|
||||
t_input = torch.cat([t, t], dim=0)
|
||||
crossattn_input = torch.cat([crossattn_emb, neg_crossattn_emb], dim=0)
|
||||
padding_input = torch.cat([padding_mask, padding_mask], dim=0)
|
||||
|
||||
model_output = dit(x_input, t_input, crossattn_input, padding_mask=padding_input)
|
||||
model_output = model_output.float()
|
||||
|
||||
pos_out, neg_out = model_output.chunk(2)
|
||||
model_output = neg_out + guidance_scale * (pos_out - neg_out)
|
||||
else:
|
||||
model_output = dit(x, t, crossattn_emb, padding_mask=padding_mask)
|
||||
model_output = model_output.float()
|
||||
|
||||
# Euler step: x_{t-1} = x_t - (sigma_t - sigma_{t-1}) * model_output
|
||||
dt = sigmas[i + 1] - sigma
|
||||
x = x + model_output * dt
|
||||
x = x.to(dtype)
|
||||
|
||||
dit.prepare_block_swap_before_forward()
|
||||
return x
|
||||
|
||||
|
||||
def sample_images(
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
epoch,
|
||||
steps,
|
||||
dit,
|
||||
vae,
|
||||
vae_scale,
|
||||
text_encoder,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
sample_prompts_te_outputs=None,
|
||||
prompt_replacement=None,
|
||||
):
|
||||
"""Generate sample images during training.
|
||||
|
||||
This is a simplified sampler for Anima - it generates images using the current model state.
|
||||
"""
|
||||
if steps == 0:
|
||||
if not args.sample_at_first:
|
||||
return
|
||||
else:
|
||||
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
||||
return
|
||||
if args.sample_every_n_epochs is not None:
|
||||
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
||||
return
|
||||
else:
|
||||
if steps % args.sample_every_n_steps != 0 or epoch is not None:
|
||||
return
|
||||
|
||||
logger.info(f"Generating sample images at step {steps}")
|
||||
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
|
||||
logger.error(f"No prompt file: {args.sample_prompts}")
|
||||
return
|
||||
|
||||
# Unwrap models
|
||||
dit = accelerator.unwrap_model(dit)
|
||||
if text_encoder is not None:
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
prompts = train_util.load_prompts(args.sample_prompts)
|
||||
save_dir = os.path.join(args.output_dir, "sample")
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Save RNG state
|
||||
rng_state = torch.get_rng_state()
|
||||
cuda_rng_state = None
|
||||
try:
|
||||
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
for prompt_dict in prompts:
|
||||
_sample_image_inference(
|
||||
accelerator, args, dit, text_encoder, vae, vae_scale,
|
||||
tokenize_strategy, text_encoding_strategy,
|
||||
save_dir, prompt_dict, epoch, steps,
|
||||
sample_prompts_te_outputs, prompt_replacement,
|
||||
)
|
||||
|
||||
# Restore RNG state
|
||||
torch.set_rng_state(rng_state)
|
||||
if cuda_rng_state is not None:
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
|
||||
def _sample_image_inference(
|
||||
accelerator, args, dit, text_encoder, vae, vae_scale,
|
||||
tokenize_strategy, text_encoding_strategy,
|
||||
save_dir, prompt_dict, epoch, steps,
|
||||
sample_prompts_te_outputs, prompt_replacement,
|
||||
):
|
||||
"""Generate a single sample image."""
|
||||
prompt = prompt_dict.get("prompt", "")
|
||||
negative_prompt = prompt_dict.get("negative_prompt", "")
|
||||
sample_steps = prompt_dict.get("sample_steps", 30)
|
||||
width = prompt_dict.get("width", 512)
|
||||
height = prompt_dict.get("height", 512)
|
||||
scale = prompt_dict.get("scale", 7.5)
|
||||
seed = prompt_dict.get("seed")
|
||||
|
||||
if prompt_replacement is not None:
|
||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
if negative_prompt:
|
||||
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed) # seed all CUDA devices for multi-GPU
|
||||
|
||||
height = max(64, height - height % 16)
|
||||
width = max(64, width - width % 16)
|
||||
|
||||
logger.info(f" prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}")
|
||||
|
||||
# Encode prompt
|
||||
def encode_prompt(prpt):
|
||||
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
|
||||
return sample_prompts_te_outputs[prpt]
|
||||
if text_encoder is not None:
|
||||
tokens = tokenize_strategy.tokenize(prpt)
|
||||
encoded = text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
|
||||
return encoded
|
||||
return None
|
||||
|
||||
encoded = encode_prompt(prompt)
|
||||
if encoded is None:
|
||||
logger.warning("Cannot encode prompt, skipping sample")
|
||||
return
|
||||
|
||||
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = encoded
|
||||
|
||||
# Convert to tensors if numpy
|
||||
if isinstance(prompt_embeds, np.ndarray):
|
||||
prompt_embeds = torch.from_numpy(prompt_embeds).unsqueeze(0)
|
||||
attn_mask = torch.from_numpy(attn_mask).unsqueeze(0)
|
||||
t5_input_ids = torch.from_numpy(t5_input_ids).unsqueeze(0)
|
||||
t5_attn_mask = torch.from_numpy(t5_attn_mask).unsqueeze(0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
|
||||
attn_mask = attn_mask.to(accelerator.device)
|
||||
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
|
||||
t5_attn_mask = t5_attn_mask.to(accelerator.device)
|
||||
|
||||
# Process through LLM adapter if available
|
||||
if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
|
||||
crossattn_emb = dit.llm_adapter(
|
||||
source_hidden_states=prompt_embeds,
|
||||
target_input_ids=t5_input_ids,
|
||||
target_attention_mask=t5_attn_mask,
|
||||
source_attention_mask=attn_mask,
|
||||
)
|
||||
crossattn_emb[~t5_attn_mask.bool()] = 0
|
||||
else:
|
||||
crossattn_emb = prompt_embeds
|
||||
|
||||
# Encode negative prompt for CFG
|
||||
neg_crossattn_emb = None
|
||||
if scale > 1.0 and negative_prompt is not None:
|
||||
neg_encoded = encode_prompt(negative_prompt)
|
||||
if neg_encoded is not None:
|
||||
neg_pe, neg_am, neg_t5_ids, neg_t5_am = neg_encoded
|
||||
if isinstance(neg_pe, np.ndarray):
|
||||
neg_pe = torch.from_numpy(neg_pe).unsqueeze(0)
|
||||
neg_am = torch.from_numpy(neg_am).unsqueeze(0)
|
||||
neg_t5_ids = torch.from_numpy(neg_t5_ids).unsqueeze(0)
|
||||
neg_t5_am = torch.from_numpy(neg_t5_am).unsqueeze(0)
|
||||
|
||||
neg_pe = neg_pe.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
|
||||
neg_am = neg_am.to(accelerator.device)
|
||||
neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
|
||||
neg_t5_am = neg_t5_am.to(accelerator.device)
|
||||
|
||||
if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
|
||||
neg_crossattn_emb = dit.llm_adapter(
|
||||
source_hidden_states=neg_pe,
|
||||
target_input_ids=neg_t5_ids,
|
||||
target_attention_mask=neg_t5_am,
|
||||
source_attention_mask=neg_am,
|
||||
)
|
||||
neg_crossattn_emb[~neg_t5_am.bool()] = 0
|
||||
else:
|
||||
neg_crossattn_emb = neg_pe
|
||||
|
||||
# Generate sample
|
||||
clean_memory_on_device(accelerator.device)
|
||||
latents = do_sample(
|
||||
height, width, seed, dit, crossattn_emb,
|
||||
sample_steps, dit.t_embedding_norm.weight.dtype,
|
||||
accelerator.device, scale, neg_crossattn_emb,
|
||||
)
|
||||
|
||||
# Decode latents
|
||||
clean_memory_on_device(accelerator.device)
|
||||
org_vae_device = next(vae.parameters()).device
|
||||
vae.to(accelerator.device)
|
||||
decoded = vae.decode(latents.to(next(vae.parameters()).device, dtype=next(vae.parameters()).dtype), vae_scale)
|
||||
vae.to(org_vae_device)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# Convert to image
|
||||
image = decoded.float()
|
||||
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
|
||||
# Remove temporal dim if present
|
||||
if image.ndim == 4:
|
||||
image = image[:, 0, :, :]
|
||||
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
|
||||
decoded_np = decoded_np.astype(np.uint8)
|
||||
|
||||
image = Image.fromarray(decoded_np)
|
||||
|
||||
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
||||
seed_suffix = "" if seed is None else f"_{seed}"
|
||||
i = prompt_dict.get("enum", 0)
|
||||
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
|
||||
image.save(os.path.join(save_dir, img_filename))
|
||||
|
||||
# Log to wandb if enabled
|
||||
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
||||
wandb_tracker = accelerator.get_tracker("wandb")
|
||||
import wandb
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)
|
||||
Reference in New Issue
Block a user