mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
fix: update argument names in anima_train_utils to align with other archtectures
This commit is contained in:
@@ -859,11 +859,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
anima_train_utils.add_anima_training_arguments(parser)
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--blockwise_fused_optimizers",
|
||||
action="store_true",
|
||||
help="enable blockwise optimizers for fused backward pass and optimizer step",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--blockwise_fused_optimizers",
|
||||
# action="store_true",
|
||||
# help="enable blockwise optimizers for fused backward pass and optimizer step",
|
||||
# )
|
||||
parser.add_argument(
|
||||
"--cpu_offload_checkpointing",
|
||||
action="store_true",
|
||||
|
||||
@@ -15,6 +15,7 @@ from tqdm import tqdm
|
||||
from PIL import Image
|
||||
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library import anima_models, anima_utils, strategy_base, train_util, qwen_image_autoencoder_kl
|
||||
|
||||
init_ipex()
|
||||
|
||||
@@ -25,10 +26,6 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import anima_models, anima_utils, strategy_base, train_util
|
||||
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas
|
||||
|
||||
|
||||
# Anima-specific training arguments
|
||||
|
||||
@@ -36,19 +33,7 @@ from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas
|
||||
def add_anima_training_arguments(parser: argparse.ArgumentParser):
|
||||
"""Add Anima-specific training arguments to the parser."""
|
||||
parser.add_argument(
|
||||
"--dit_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Anima DiT model safetensors file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to WanVAE safetensors/pth file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--qwen3_path",
|
||||
"--qwen3",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Qwen3-0.6B model (safetensors file or directory)",
|
||||
@@ -87,7 +72,7 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
|
||||
"--mod_lr",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze",
|
||||
help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze. Note: mod layers are not included in LoRA by default.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--t5_tokenizer_path",
|
||||
@@ -114,34 +99,29 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
|
||||
help="Timestep distribution shift for rectified flow training (default: 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timestep_sample_method",
|
||||
"--timestep_sampling",
|
||||
type=str,
|
||||
default="logit_normal",
|
||||
choices=["logit_normal", "uniform"],
|
||||
help="Timestep sampling method (default: logit_normal)",
|
||||
default="sigmoid",
|
||||
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
|
||||
help="Timestep sampling method (default: sigmoid (logit normal))",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sigmoid_scale",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Scale factor for logit_normal timestep sampling (default: 1.0)",
|
||||
help="Scale factor for sigmoid (logit_normal) timestep sampling (default: 1.0)",
|
||||
)
|
||||
# Note: --caption_dropout_rate is defined by base add_dataset_arguments().
|
||||
# Anima uses embedding-level dropout (via AnimaTextEncodingStrategy.dropout_rate)
|
||||
# instead of dataset-level caption dropout, so the subset caption_dropout_rate
|
||||
# is zeroed out in the training scripts to allow caching.
|
||||
parser.add_argument(
|
||||
"--transformer_dtype",
|
||||
type=str,
|
||||
"--attn_mode",
|
||||
choices=["torch", "xformers", "flash", "sageattn", "sdpa"], # "sdpa" is for backward compatibility
|
||||
default=None,
|
||||
choices=["float16", "bfloat16", "float32", None],
|
||||
help="Separate dtype for transformer blocks. If None, uses same as mixed_precision",
|
||||
help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. sageattn does not support training (inference only). This option overrides --xformers or --sdpa."
|
||||
" / 使用するAttentionの実装。デフォルトはNone(torch)です。xformersは--split_attnの指定が必要です。sageattnはトレーニングをサポートしていません(推論のみ)。このオプションは--xformersまたは--sdpaを上書きします。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flash_attn",
|
||||
"--split_attn",
|
||||
action="store_true",
|
||||
help="Use Flash Attention for DiT self/cross-attention (requires flash-attn package). "
|
||||
"Falls back to PyTorch SDPA if flash-attn is not installed.",
|
||||
help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する",
|
||||
)
|
||||
|
||||
|
||||
@@ -218,7 +198,7 @@ def get_noisy_model_input_and_timesteps(
|
||||
def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute loss weighting for Anima training.
|
||||
|
||||
Same schemes as SD3 but can add Anima-specific ones.
|
||||
Same schemes as SD3 but can add Anima-specific ones if needed in future.
|
||||
"""
|
||||
if weighting_scheme == "sigma_sqrt":
|
||||
weighting = (sigmas**-2.0).float()
|
||||
@@ -382,6 +362,7 @@ def do_sample(
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
guidance_scale: float = 1.0,
|
||||
flow_shift: float = 3.0,
|
||||
neg_crossattn_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Generate a sample using Euler discrete sampling for rectified flow.
|
||||
@@ -395,6 +376,7 @@ def do_sample(
|
||||
dtype: Compute dtype
|
||||
device: Compute device
|
||||
guidance_scale: CFG scale (1.0 = no guidance)
|
||||
flow_shift: Flow shift parameter for rectified flow
|
||||
neg_crossattn_emb: Negative cross-attention embeddings for CFG
|
||||
|
||||
Returns:
|
||||
@@ -414,6 +396,9 @@ def do_sample(
|
||||
|
||||
# Timestep schedule: linear from 1.0 to 0.0
|
||||
sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)
|
||||
flow_shift = float(flow_shift)
|
||||
if flow_shift != 1.0:
|
||||
sigmas = (sigmas * flow_shift) / (1 + (flow_shift - 1) * sigmas)
|
||||
|
||||
# Start from pure noise
|
||||
x = noise.clone()
|
||||
@@ -461,7 +446,6 @@ def sample_images(
|
||||
steps,
|
||||
dit,
|
||||
vae,
|
||||
vae_scale,
|
||||
text_encoder,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
@@ -515,7 +499,6 @@ def sample_images(
|
||||
dit,
|
||||
text_encoder,
|
||||
vae,
|
||||
vae_scale,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
save_dir,
|
||||
@@ -539,8 +522,7 @@ def _sample_image_inference(
|
||||
args,
|
||||
dit,
|
||||
text_encoder,
|
||||
vae,
|
||||
vae_scale,
|
||||
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
save_dir,
|
||||
@@ -558,6 +540,7 @@ def _sample_image_inference(
|
||||
height = prompt_dict.get("height", 512)
|
||||
scale = prompt_dict.get("scale", 7.5)
|
||||
seed = prompt_dict.get("seed")
|
||||
flow_shift = prompt_dict.get("flow_shift", 3.0)
|
||||
|
||||
if prompt_replacement is not None:
|
||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
@@ -571,7 +554,9 @@ def _sample_image_inference(
|
||||
height = max(64, height - height % 16)
|
||||
width = max(64, width - width % 16)
|
||||
|
||||
logger.info(f" prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}")
|
||||
logger.info(
|
||||
f" prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}, flow_shift: {flow_shift}, seed: {seed}"
|
||||
)
|
||||
|
||||
# Encode prompt
|
||||
def encode_prompt(prpt):
|
||||
@@ -597,14 +582,14 @@ def _sample_image_inference(
|
||||
t5_input_ids = torch.from_numpy(t5_input_ids).unsqueeze(0)
|
||||
t5_attn_mask = torch.from_numpy(t5_attn_mask).unsqueeze(0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
|
||||
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.dtype)
|
||||
attn_mask = attn_mask.to(accelerator.device)
|
||||
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
|
||||
t5_attn_mask = t5_attn_mask.to(accelerator.device)
|
||||
|
||||
# Process through LLM adapter if available
|
||||
if dit.use_llm_adapter and hasattr(dit, "llm_adapter"):
|
||||
crossattn_emb = dit.llm_adapter(
|
||||
if dit.net.use_llm_adapter:
|
||||
crossattn_emb = dit.net.llm_adapter(
|
||||
source_hidden_states=prompt_embeds,
|
||||
target_input_ids=t5_input_ids,
|
||||
target_attention_mask=t5_attn_mask,
|
||||
@@ -626,13 +611,13 @@ def _sample_image_inference(
|
||||
neg_t5_ids = torch.from_numpy(neg_t5_ids).unsqueeze(0)
|
||||
neg_t5_am = torch.from_numpy(neg_t5_am).unsqueeze(0)
|
||||
|
||||
neg_pe = neg_pe.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
|
||||
neg_pe = neg_pe.to(accelerator.device, dtype=dit.dtype)
|
||||
neg_am = neg_am.to(accelerator.device)
|
||||
neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
|
||||
neg_t5_am = neg_t5_am.to(accelerator.device)
|
||||
|
||||
if dit.use_llm_adapter and hasattr(dit, "llm_adapter"):
|
||||
neg_crossattn_emb = dit.llm_adapter(
|
||||
if dit.net.use_llm_adapter:
|
||||
neg_crossattn_emb = dit.net.llm_adapter(
|
||||
source_hidden_states=neg_pe,
|
||||
target_input_ids=neg_t5_ids,
|
||||
target_attention_mask=neg_t5_am,
|
||||
@@ -645,23 +630,14 @@ def _sample_image_inference(
|
||||
# Generate sample
|
||||
clean_memory_on_device(accelerator.device)
|
||||
latents = do_sample(
|
||||
height,
|
||||
width,
|
||||
seed,
|
||||
dit,
|
||||
crossattn_emb,
|
||||
sample_steps,
|
||||
dit.t_embedding_norm.weight.dtype,
|
||||
accelerator.device,
|
||||
scale,
|
||||
neg_crossattn_emb,
|
||||
height, width, seed, dit, crossattn_emb, sample_steps, dit.dtype, accelerator.device, scale, flow_shift, neg_crossattn_emb
|
||||
)
|
||||
|
||||
# Decode latents
|
||||
clean_memory_on_device(accelerator.device)
|
||||
org_vae_device = next(vae.parameters()).device
|
||||
org_vae_device = vae.device
|
||||
vae.to(accelerator.device)
|
||||
decoded = vae.decode(latents.to(next(vae.parameters()).device, dtype=next(vae.parameters()).dtype), vae_scale)
|
||||
decoded = vae.decode_to_pixels(latents)
|
||||
vae.to(org_vae_device)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user