feat: refactor Anima training script

Author: Kohya S
Date:   2026-02-10 21:27:22 +09:00
Parent: 6d08c93b23
Commit: 8d2d286a13
2 changed files with 90 additions and 210 deletions

View File

@@ -3,6 +3,7 @@
import argparse
from concurrent.futures import ThreadPoolExecutor
import copy
import gc
import math
import os
from multiprocessing import Value
@@ -12,8 +13,9 @@ import toml
from tqdm import tqdm
import torch
from library import utils
from library import flux_train_utils, qwen_image_autoencoder_kl, utils
from library.device_utils import init_ipex, clean_memory_on_device
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
init_ipex()
@@ -56,7 +58,7 @@ def train(args):
logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
if getattr(args, "unsloth_offload_checkpointing", False):
if args.unsloth_offload_checkpointing:
if not args.gradient_checkpointing:
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
@@ -66,19 +68,19 @@ def train(args):
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
assert (args.blocks_to_swap is None or args.blocks_to_swap == 0) or not getattr(
args, "unsloth_offload_checkpointing", False
), "blocks_to_swap is not supported with unsloth_offload_checkpointing"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.unsloth_offload_checkpointing, "blocks_to_swap is not supported with unsloth_offload_checkpointing"
# Flash attention: validate availability
if getattr(args, "flash_attn", False):
try:
import flash_attn # noqa: F401
# # Flash attention: validate availability
# if args.flash_attn:
# try:
# import flash_attn # noqa: F401
logger.info("Flash Attention enabled for DiT blocks")
except ImportError:
logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
args.flash_attn = False
# logger.info("Flash Attention enabled for DiT blocks")
# except ImportError:
# logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
# args.flash_attn = False
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
@@ -140,26 +142,13 @@ def train(args):
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial downscale = 8
# Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
# dataset-level caption dropout, so we save the rate and zero out subset-level
# caption_dropout_rate to allow text encoder output caching.
caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0)
if caption_dropout_rate > 0:
logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
for dataset in train_dataset_group.datasets:
for subset in dataset.subsets:
subset.caption_dropout_rate = 0.0
train_dataset_group.verify_bucket_reso_steps(16) # Qwen-Image VAE spatial downscale = 8 * patch size = 2
if args.debug_dataset:
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
False,
False,
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
)
)
train_dataset_group.set_current_strategies()
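The removed block above notes that Anima applies caption dropout at the embedding level (inside AnimaTextEncodingStrategy) rather than at the dataset level, which is what keeps text encoder output caching compatible with dropout. A minimal sketch of that idea, assuming cached prompt embeddings of shape (B, L, D) and a pre-computed unconditional embedding; the function name and signature are illustrative, not the repository's API:

```python
import torch

def apply_embedding_dropout(
    prompt_embeds: torch.Tensor,   # (B, L, D) cached prompt embeddings
    uncond_embeds: torch.Tensor,   # (L, D) or (1, L, D) embedding of the empty prompt
    dropout_rate: float,
) -> torch.Tensor:
    """Swap each sample's embedding for the unconditional one with probability dropout_rate."""
    if dropout_rate <= 0.0:
        return prompt_embeds
    b = prompt_embeds.shape[0]
    drop = torch.rand(b, device=prompt_embeds.device) < dropout_rate
    drop = drop.view(b, 1, 1)  # broadcast over sequence and hidden dims
    return torch.where(drop, uncond_embeds.to(prompt_embeds), prompt_embeds)
```

Because the replacement operates on already-encoded embeddings, the subset-level caption_dropout_rate can be zeroed and the text encoder outputs cached once.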
@@ -173,8 +162,8 @@ def train(args):
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used"
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
assert train_dataset_group.is_text_encoder_output_cacheable(
cache_supports_dropout=True
), "when caching text encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
# prepare accelerator
@@ -184,20 +173,10 @@ def train(args):
# mixed precision dtype
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# parse transformer_dtype
transformer_dtype = None
if hasattr(args, "transformer_dtype") and args.transformer_dtype is not None:
transformer_dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)
# Load tokenizers and set strategies
logger.info("Loading tokenizers...")
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=weight_dtype, device="cpu")
t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
t5_tokenizer = anima_utils.load_t5_tokenizer(args.t5_tokenizer_path)
# Set tokenize strategy
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
@@ -208,11 +187,7 @@ def train(args):
)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
# Set text encoding strategy
caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0)
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
dropout_rate=caption_dropout_rate,
)
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# Prepare text encoder (always frozen for Anima)
@@ -226,10 +201,7 @@ def train(args):
qwen3_text_encoder.eval()
text_encoder_caching_strategy = strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=False,
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, is_partial=False
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
@@ -248,25 +220,19 @@ def train(args):
logger.info(f" cache TE outputs for: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy,
[qwen3_text_encoder],
tokens_and_masks,
enable_dropout=False,
tokenize_strategy, [qwen3_text_encoder], tokens_and_masks
)
# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
with accelerator.autocast():
text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
accelerator.wait_for_everyone()
# free text encoder memory
qwen3_text_encoder = None
gc.collect() # Force garbage collection to free memory
clean_memory_on_device(accelerator.device)
# Load VAE and cache latents
logger.info("Loading Anima VAE...")
vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu")
vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu")
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
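cache_uncond_embeddings is called here so that an "empty prompt" embedding survives after qwen3_text_encoder is set to None and garbage-collected. A rough sketch of the pattern, assuming a Hugging Face-style tokenizer and encoder; the actual strategy method's signature differs:

```python
import torch

@torch.no_grad()
def cache_uncond_embedding(text_encoder, tokenizer, device="cuda") -> torch.Tensor:
    # Encode the empty caption once while the text encoder is still loaded.
    tokens = tokenizer("", return_tensors="pt").to(device)
    hidden = text_encoder(**tokens).last_hidden_state
    # Keep a CPU copy so it remains usable after the encoder is deleted to free VRAM.
    return hidden.float().cpu()
```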
@@ -281,24 +247,16 @@ def train(args):
# Load DiT (MiniTrainDIT + optional LLM Adapter)
logger.info("Loading Anima DiT...")
dit = anima_utils.load_anima_dit(
args.dit_path,
dtype=weight_dtype,
device="cpu",
transformer_dtype=transformer_dtype,
llm_adapter_path=getattr(args, "llm_adapter_path", None),
disable_mmap=getattr(args, "disable_mmap_load_safetensors", False),
dit = anima_utils.load_anima_model(
"cpu", args.pretrained_model_name_or_path, args.attn_mode, args.split_attn, "cpu", dit_weight_dtype=None
)
if args.gradient_checkpointing:
dit.enable_gradient_checkpointing(
cpu_offload=args.cpu_offload_checkpointing,
unsloth_offload=getattr(args, "unsloth_offload_checkpointing", False),
unsloth_offload=args.unsloth_offload_checkpointing,
)
if getattr(args, "flash_attn", False):
dit.set_flash_attn(True)
train_dit = args.learning_rate != 0
dit.requires_grad_(train_dit)
if not train_dit:
@@ -314,19 +272,17 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# Move scale tensors to same device as VAE for on-the-fly encoding
vae_scale = [s.to(accelerator.device) if isinstance(s, torch.Tensor) else s for s in vae_scale]
# Setup optimizer with parameter groups
if train_dit:
param_groups = anima_train_utils.get_anima_param_groups(
dit,
base_lr=args.learning_rate,
self_attn_lr=getattr(args, "self_attn_lr", None),
cross_attn_lr=getattr(args, "cross_attn_lr", None),
mlp_lr=getattr(args, "mlp_lr", None),
mod_lr=getattr(args, "mod_lr", None),
llm_adapter_lr=getattr(args, "llm_adapter_lr", None),
self_attn_lr=args.self_attn_lr,
cross_attn_lr=args.cross_attn_lr,
mlp_lr=args.mlp_lr,
mod_lr=args.mod_lr,
llm_adapter_lr=args.llm_adapter_lr,
)
else:
param_groups = []
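get_anima_param_groups assigns separate learning rates to self-attention, cross-attention, MLP, modulation, and LLM-adapter parameters. A hedged sketch of how such per-component groups are commonly built by name matching and handed to a torch optimizer; the real grouping rules in anima_train_utils may differ:

```python
import torch
from typing import Dict, List

def build_param_groups(model: torch.nn.Module, base_lr: float, overrides: Dict[str, float]) -> List[dict]:
    groups: Dict[str, dict] = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # First override key that appears in the parameter name wins (e.g. "self_attn", "llm_adapter").
        key = next((k for k in overrides if k in name), "base")
        lr = overrides.get(key, base_lr)
        groups.setdefault(key, {"params": [], "lr": lr})["params"].append(param)
    return list(groups.values())

# The resulting groups feed directly into any torch optimizer, e.g.:
# optimizer = torch.optim.AdamW(build_param_groups(dit, 1e-4, {"self_attn": 5e-5, "llm_adapter": 2e-4}))
```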
@@ -348,57 +304,7 @@ def train(args):
# prepare optimizer
accelerator.print("prepare optimizer, data loader etc.")
if args.blockwise_fused_optimizers:
# Split params into per-block groups for blockwise fused optimizer
# Build param_id → lr mapping from param_groups to propagate per-component LRs
param_lr_map = {}
for group in param_groups:
for p in group["params"]:
param_lr_map[id(p)] = group["lr"]
grouped_params = []
param_group = {}
named_parameters = list(dit.named_parameters())
for name, p in named_parameters:
if not p.requires_grad:
continue
# Determine block type and index
if name.startswith("blocks."):
block_index = int(name.split(".")[1])
block_type = "blocks"
elif name.startswith("llm_adapter.blocks."):
block_index = int(name.split(".")[2])
block_type = "llm_adapter"
else:
block_index = -1
block_type = "other"
param_group_key = (block_type, block_index)
if param_group_key not in param_group:
param_group[param_group_key] = []
param_group[param_group_key].append(p)
for param_group_key, params in param_group.items():
# Use per-component LR from param_groups if available
lr = param_lr_map.get(id(params[0]), args.learning_rate)
grouped_params.append({"params": params, "lr": lr})
num_params = sum(p.numel() for p in params)
accelerator.print(f"block {param_group_key}: {num_params} parameters, lr={lr}")
# Create per-group optimizers
optimizers = []
for group in grouped_params:
_, _, opt = train_util.get_optimizer(args, trainable_params=[group])
optimizers.append(opt)
optimizer = optimizers[0] # avoid error in following code
logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")
if train_util.is_schedulefree_optimizer(optimizers[0], args):
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
optimizer_train_fn = lambda: None
optimizer_eval_fn = lambda: None
elif args.fused_backward_pass:
if args.fused_backward_pass:
# Pass per-component param_groups directly to preserve per-component LRs
_, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
@@ -429,21 +335,19 @@ def train(args):
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr scheduler
if args.blockwise_fused_optimizers:
lr_schedulers = [train_util.get_scheduler_fix(args, opt, accelerator.num_processes) for opt in optimizers]
lr_scheduler = lr_schedulers[0] # avoid error in following code
else:
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# full fp16/bf16 training
dit_weight_dtype = weight_dtype
if args.full_fp16:
assert args.mixed_precision == "fp16", "full_fp16 requires mixed_precision='fp16'"
accelerator.print("enable full fp16 training.")
dit.to(weight_dtype)
elif args.full_bf16:
assert args.mixed_precision == "bf16", "full_bf16 requires mixed_precision='bf16'"
accelerator.print("enable full bf16 training.")
dit.to(weight_dtype)
else:
dit_weight_dtype = torch.float32 # Default to float32
dit.to(dit_weight_dtype) # convert dit to target weight dtype
# move text encoder to GPU if not cached
if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
@@ -485,6 +389,7 @@ def train(args):
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
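patch_adafactor_fused rewires Adafactor so each parameter can be stepped from a post-accumulate gradient hook, which is what fused_backward_pass relies on: gradients are consumed and freed one parameter at a time instead of being held for the whole model. A self-contained sketch of the same pattern using a plain SGD update rather than the patched Adafactor:

```python
import torch

def install_fused_sgd(model: torch.nn.Module, lr: float = 1e-4) -> None:
    """Apply each parameter's update as soon as its gradient is accumulated, then free it."""
    for p in model.parameters():
        if not p.requires_grad:
            continue

        def hook(param: torch.Tensor, lr: float = lr) -> None:
            with torch.no_grad():
                param.add_(param.grad, alpha=-lr)  # in-place SGD step for this parameter only
            param.grad = None                      # release the gradient immediately

        p.register_post_accumulate_grad_hook(hook)
```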
@@ -504,53 +409,28 @@ def train(args):
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_group))
elif args.blockwise_fused_optimizers:
# Prepare additional optimizers and lr schedulers
for i in range(1, len(optimizers)):
optimizers[i] = accelerator.prepare(optimizers[i])
lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
# Counters for blockwise gradient hook
optimizer_hooked_count = {}
num_parameters_per_group = [0] * len(optimizers)
parameter_optimizer_map = {}
for opt_idx, opt in enumerate(optimizers):
for param_group in opt.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
def grad_hook(parameter: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
i = parameter_optimizer_map[parameter]
optimizer_hooked_count[i] += 1
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
optimizers[i].step()
optimizers[i].zero_grad(set_to_none=True)
parameter.register_post_accumulate_grad_hook(grad_hook)
parameter_optimizer_map[parameter] = opt_idx
num_parameters_per_group[opt_idx] += 1
# Training loop
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
accelerator.print("running training")
accelerator.print(f" num examples: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch: {len(train_dataloader)}")
accelerator.print(f" num epochs: {num_train_epochs}")
accelerator.print(f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps: {args.max_train_steps}")
accelerator.print("running training / 学習開始")
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
@@ -581,7 +461,6 @@ def train(args):
global_step,
dit,
vae,
vae_scale,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
@@ -594,13 +473,11 @@ def train(args):
# Show model info
unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None
if unwrapped_dit is not None:
logger.info(
f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}"
)
logger.info(f"dit device: {unwrapped_dit.device}, dtype: {unwrapped_dit.dtype}")
if qwen3_text_encoder is not None:
logger.info(f"qwen3 device: {next(qwen3_text_encoder.parameters()).device}")
logger.info(f"qwen3 device: {qwen3_text_encoder.device}")
if vae is not None:
logger.info(f"vae device: {next(vae.parameters()).device}")
logger.info(f"vae device: {vae.device}")
loss_recorder = train_util.LossRecorder()
epoch = 0
@@ -614,19 +491,17 @@ def train(args):
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
if args.blockwise_fused_optimizers:
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
with accelerator.accumulate(*training_models):
# Get latents
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
latents = batch["latents"].to(accelerator.device, dtype=dit_weight_dtype)
if latents.ndim == 5: # Fallback for 5D latents (old cache)
latents = latents.squeeze(2) # (B, C, 1, H, W) -> (B, C, H, W)
else:
with torch.no_grad():
# images are already [-1, 1] from IMAGE_TRANSFORMS, add temporal dim
images = batch["images"].to(accelerator.device, dtype=weight_dtype)
images = images.unsqueeze(2) # (B, C, 1, H, W)
latents = vae.encode(images, vae_scale).to(accelerator.device, dtype=weight_dtype)
latents = vae.encode_pixels_to_latents(images).to(accelerator.device, dtype=dit_weight_dtype)
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
@@ -636,21 +511,24 @@ def train(args):
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
# Cached outputs
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
caption_dropout_rates = text_encoder_outputs_list[-1]
text_encoder_outputs_list = text_encoder_outputs_list[:-1]
# Apply caption dropout to cached outputs
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(
*text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
)
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list
else:
# Encode on-the-fly
input_ids_list = batch["input_ids_list"]
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = input_ids_list
with torch.no_grad():
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
tokenize_strategy,
[qwen3_text_encoder],
[qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask],
tokenize_strategy, [qwen3_text_encoder], input_ids_list
)
# Move to device
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype)
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit_weight_dtype)
attn_mask = attn_mask.to(accelerator.device)
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
t5_attn_mask = t5_attn_mask.to(accelerator.device)
@@ -658,9 +536,11 @@ def train(args):
# Noise and timesteps
noise = torch.randn_like(latents)
noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
args, latents, noise, accelerator.device, weight_dtype
# Get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, accelerator.device, dit_weight_dtype
)
timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
# NaN checks
if torch.any(torch.isnan(noisy_model_input)):
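The refactor switches to flux_train_utils.get_noisy_model_input_and_timesteps with a FlowMatchEulerDiscreteScheduler, then rescales timesteps from [0, 1000] to [0, 1] for the DiT. A simplified sketch of what such a helper computes, assuming a plain shifted-sigma schedule; the real helper supports several timestep sampling modes:

```python
import torch

def get_noisy_input_and_timesteps(latents, noise, shift=3.0):
    b = latents.shape[0]
    u = torch.rand(b, device=latents.device)               # uniform draw per sample
    sigmas = shift * u / (1.0 + (shift - 1.0) * u)          # discrete-flow shift toward noisier sigmas
    timesteps = sigmas * 1000.0                             # scheduler timesteps, later divided by 1000
    s = sigmas.view(b, 1, 1, 1)
    noisy_model_input = (1.0 - s) * latents + s * noise     # linear path between data and noise
    return noisy_model_input, timesteps, sigmas
```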
@@ -672,12 +552,10 @@ def train(args):
bs = latents.shape[0]
h_latent = latents.shape[-2]
w_latent = latents.shape[-1]
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device)
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=dit_weight_dtype, device=accelerator.device)
# DiT forward (LLM adapter runs inside forward for DDP gradient sync)
if is_swapping_blocks:
accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, (B, C, 1, H, W)
with accelerator.autocast():
model_pred = dit(
noisy_model_input,
@@ -688,6 +566,7 @@ def train(args):
t5_input_ids=t5_input_ids,
t5_attn_mask=t5_attn_mask,
)
model_pred = model_pred.squeeze(2) # 5D to 4D, (B, C, H, W)
# Compute loss (rectified flow: target = noise - latents)
target = noise - latents
@@ -702,7 +581,7 @@ def train(args):
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3, 4]) # (B, C, T, H, W) -> (B,)
loss = loss.mean([1, 2, 3]) # (B, C, H, W) -> (B,)
if weighting is not None:
loss = loss * weighting
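The loss follows the rectified-flow objective: the network predicts the velocity noise − latents, and the per-sample loss is now reduced over (C, H, W) since the latents are 4D. A minimal equivalent using plain MSE, leaving out conditional_loss, masking, and timestep weighting:

```python
import torch
import torch.nn.functional as F

def rectified_flow_loss(model_pred: torch.Tensor, latents: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    target = noise - latents                                         # velocity target
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    return loss.mean(dim=[1, 2, 3])                                  # (B, C, H, W) -> (B,)
```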
@@ -713,7 +592,7 @@ def train(args):
accelerator.backward(loss)
if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
if not args.fused_backward_pass:
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
@@ -726,9 +605,6 @@ def train(args):
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
if args.blockwise_fused_optimizers:
for i in range(1, len(optimizers)):
lr_schedulers[i].step()
# Checks if the accelerator has performed an optimization step
if accelerator.sync_gradients:
@@ -743,7 +619,6 @@ def train(args):
global_step,
dit,
vae,
vae_scale,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
@@ -812,7 +687,6 @@ def train(args):
global_step,
dit,
vae,
vae_scale,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
@@ -859,11 +733,6 @@ def setup_parser() -> argparse.ArgumentParser:
anima_train_utils.add_anima_training_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
# parser.add_argument(
# "--blockwise_fused_optimizers",
# action="store_true",
# help="enable blockwise optimizers for fused backward pass and optimizer step",
# )
parser.add_argument(
"--cpu_offload_checkpointing",
action="store_true",
@@ -891,4 +760,7 @@ if __name__ == "__main__":
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
if args.attn_mode == "sdpa":
args.attn_mode = "torch" # backward compatibility
train(args)

View File

@@ -37,6 +37,14 @@ class AttentionParams:
cu_seqlens: Optional[torch.Tensor] = None
max_seqlen: Optional[int] = None
@property
def supports_fp32(self) -> bool:
return self.attn_mode not in ["flash"]
@property
def requires_same_dtype(self) -> bool:
return self.attn_mode in ["xformers"]
@staticmethod
def create_attention_params(attn_mode: Optional[str], split_attn: bool) -> "AttentionParams":
return AttentionParams(attn_mode, split_attn)
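The new supports_fp32 and requires_same_dtype properties describe backend constraints (flash attention kernels do not take fp32 inputs; xformers wants q/k/v in a single dtype). A hedged sketch of how a caller might use them before dispatching; the real attention() wrapper may handle this differently:

```python
import torch

def prepare_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_params):
    # Downcast when the selected backend cannot run in float32 (e.g. flash attention).
    if not attn_params.supports_fp32 and q.dtype == torch.float32:
        q, k, v = (t.to(torch.bfloat16) for t in (q, k, v))
    # Unify dtypes when the backend requires identical q/k/v dtypes (e.g. xformers).
    if attn_params.requires_same_dtype:
        k, v = k.to(q.dtype), v.to(q.dtype)
    return q, k, v
```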
@@ -95,7 +103,7 @@ def attention(
qkv_or_q: Query tensor [B, L, H, D], or a list of such tensors.
k: Key tensor [B, L, H, D].
v: Value tensor [B, L, H, D].
attn_param: Attention parameters including mask and sequence lengths.
attn_params: Attention parameters including mask and sequence lengths.
drop_rate: Attention dropout rate.
Returns: