diff --git a/anima_train.py b/anima_train.py index ae3cf6a0..13c15f0c 100644 --- a/anima_train.py +++ b/anima_train.py @@ -3,6 +3,7 @@ import argparse from concurrent.futures import ThreadPoolExecutor import copy +import gc import math import os from multiprocessing import Value @@ -12,8 +13,9 @@ import toml from tqdm import tqdm import torch -from library import utils +from library import flux_train_utils, qwen_image_autoencoder_kl, utils from library.device_utils import init_ipex, clean_memory_on_device +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler init_ipex() @@ -56,7 +58,7 @@ def train(args): logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled") args.gradient_checkpointing = True - if getattr(args, "unsloth_offload_checkpointing", False): + if args.unsloth_offload_checkpointing: if not args.gradient_checkpointing: logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled") args.gradient_checkpointing = True @@ -66,19 +68,19 @@ def train(args): args.blocks_to_swap is None or args.blocks_to_swap == 0 ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing" - assert (args.blocks_to_swap is None or args.blocks_to_swap == 0) or not getattr( - args, "unsloth_offload_checkpointing", False - ), "blocks_to_swap is not supported with unsloth_offload_checkpointing" + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.unsloth_offload_checkpointing, "blocks_to_swap is not supported with unsloth_offload_checkpointing" - # Flash attention: validate availability - if getattr(args, "flash_attn", False): - try: - import flash_attn # noqa: F401 + # # Flash attention: validate availability + # if args.flash_attn: + # try: + # import flash_attn # noqa: F401 - logger.info("Flash Attention enabled for DiT blocks") - except ImportError: - logger.warning("flash_attn package not installed, falling back to PyTorch SDPA") - args.flash_attn = False + # logger.info("Flash Attention enabled for DiT blocks") + # except ImportError: + # logger.warning("flash_attn package not installed, falling back to PyTorch SDPA") + # args.flash_attn = False cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None @@ -140,26 +142,13 @@ def train(args): ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) - train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial downscale = 8 - - # Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of - # dataset-level caption dropout, so we save the rate and zero out subset-level - # caption_dropout_rate to allow text encoder output caching. - caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0) - if caption_dropout_rate > 0: - logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}") - for dataset in train_dataset_group.datasets: - for subset in dataset.subsets: - subset.caption_dropout_rate = 0.0 + train_dataset_group.verify_bucket_reso_steps(16) # Qwen-Image VAE spatial downscale = 8 * patch size = 2 if args.debug_dataset: if args.cache_text_encoder_outputs: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( strategy_anima.AnimaTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, - args.text_encoder_batch_size, - False, - False, + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False ) ) train_dataset_group.set_current_strategies() @@ -173,8 +162,8 @@ def train(args): assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used" if args.cache_text_encoder_outputs: - assert ( - train_dataset_group.is_text_encoder_output_cacheable() + assert train_dataset_group.is_text_encoder_output_cacheable( + cache_supports_dropout=True ), "when caching text encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used" # prepare accelerator @@ -184,20 +173,10 @@ def train(args): # mixed precision dtype weight_dtype, save_dtype = train_util.prepare_dtype(args) - # parse transformer_dtype - transformer_dtype = None - if hasattr(args, "transformer_dtype") and args.transformer_dtype is not None: - transformer_dtype_map = { - "float16": torch.float16, - "bfloat16": torch.bfloat16, - "float32": torch.float32, - } - transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None) - # Load tokenizers and set strategies logger.info("Loading tokenizers...") - qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=weight_dtype, device="cpu") - t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None)) + qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu") + t5_tokenizer = anima_utils.load_t5_tokenizer(args.t5_tokenizer_path) # Set tokenize strategy tokenize_strategy = strategy_anima.AnimaTokenizeStrategy( @@ -208,11 +187,7 @@ def train(args): ) strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) - # Set text encoding strategy - caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0) - text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy( - dropout_rate=caption_dropout_rate, - ) + text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy() strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) # Prepare text encoder (always frozen for Anima) @@ -226,10 +201,7 @@ def train(args): qwen3_text_encoder.eval() text_encoder_caching_strategy = strategy_anima.AnimaTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, - args.text_encoder_batch_size, - args.skip_cache_check, - is_partial=False, + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, is_partial=False ) strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) @@ -248,25 +220,19 @@ def train(args): logger.info(f" cache TE outputs for: {p}") tokens_and_masks = tokenize_strategy.tokenize(p) sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( - tokenize_strategy, - [qwen3_text_encoder], - tokens_and_masks, - enable_dropout=False, + tokenize_strategy, [qwen3_text_encoder], tokens_and_masks ) - # Pre-cache unconditional embeddings for caption dropout before text encoder is deleted - with accelerator.autocast(): - text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder]) - accelerator.wait_for_everyone() # free text encoder memory qwen3_text_encoder = None + gc.collect() # Force garbage collection to free memory clean_memory_on_device(accelerator.device) # Load VAE and cache latents logger.info("Loading Anima VAE...") - vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu") + vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu") if cache_latents: vae.to(accelerator.device, dtype=weight_dtype) @@ -281,24 +247,16 @@ def train(args): # Load DiT (MiniTrainDIT + optional LLM Adapter) logger.info("Loading Anima DiT...") - dit = anima_utils.load_anima_dit( - args.dit_path, - dtype=weight_dtype, - device="cpu", - transformer_dtype=transformer_dtype, - llm_adapter_path=getattr(args, "llm_adapter_path", None), - disable_mmap=getattr(args, "disable_mmap_load_safetensors", False), + dit = anima_utils.load_anima_model( + "cpu", args.pretrained_model_name_or_path, args.attn_mode, args.split_attn, "cpu", dit_weight_dtype=None ) if args.gradient_checkpointing: dit.enable_gradient_checkpointing( cpu_offload=args.cpu_offload_checkpointing, - unsloth_offload=getattr(args, "unsloth_offload_checkpointing", False), + unsloth_offload=args.unsloth_offload_checkpointing, ) - if getattr(args, "flash_attn", False): - dit.set_flash_attn(True) - train_dit = args.learning_rate != 0 dit.requires_grad_(train_dit) if not train_dit: @@ -314,19 +272,17 @@ def train(args): vae.requires_grad_(False) vae.eval() vae.to(accelerator.device, dtype=weight_dtype) - # Move scale tensors to same device as VAE for on-the-fly encoding - vae_scale = [s.to(accelerator.device) if isinstance(s, torch.Tensor) else s for s in vae_scale] # Setup optimizer with parameter groups if train_dit: param_groups = anima_train_utils.get_anima_param_groups( dit, base_lr=args.learning_rate, - self_attn_lr=getattr(args, "self_attn_lr", None), - cross_attn_lr=getattr(args, "cross_attn_lr", None), - mlp_lr=getattr(args, "mlp_lr", None), - mod_lr=getattr(args, "mod_lr", None), - llm_adapter_lr=getattr(args, "llm_adapter_lr", None), + self_attn_lr=args.self_attn_lr, + cross_attn_lr=args.cross_attn_lr, + mlp_lr=args.mlp_lr, + mod_lr=args.mod_lr, + llm_adapter_lr=args.llm_adapter_lr, ) else: param_groups = [] @@ -348,57 +304,7 @@ def train(args): # prepare optimizer accelerator.print("prepare optimizer, data loader etc.") - if args.blockwise_fused_optimizers: - # Split params into per-block groups for blockwise fused optimizer - # Build param_id → lr mapping from param_groups to propagate per-component LRs - param_lr_map = {} - for group in param_groups: - for p in group["params"]: - param_lr_map[id(p)] = group["lr"] - - grouped_params = [] - param_group = {} - named_parameters = list(dit.named_parameters()) - for name, p in named_parameters: - if not p.requires_grad: - continue - # Determine block type and index - if name.startswith("blocks."): - block_index = int(name.split(".")[1]) - block_type = "blocks" - elif name.startswith("llm_adapter.blocks."): - block_index = int(name.split(".")[2]) - block_type = "llm_adapter" - else: - block_index = -1 - block_type = "other" - - param_group_key = (block_type, block_index) - if param_group_key not in param_group: - param_group[param_group_key] = [] - param_group[param_group_key].append(p) - - for param_group_key, params in param_group.items(): - # Use per-component LR from param_groups if available - lr = param_lr_map.get(id(params[0]), args.learning_rate) - grouped_params.append({"params": params, "lr": lr}) - num_params = sum(p.numel() for p in params) - accelerator.print(f"block {param_group_key}: {num_params} parameters, lr={lr}") - - # Create per-group optimizers - optimizers = [] - for group in grouped_params: - _, _, opt = train_util.get_optimizer(args, trainable_params=[group]) - optimizers.append(opt) - optimizer = optimizers[0] # avoid error in following code - - logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") - - if train_util.is_schedulefree_optimizer(optimizers[0], args): - raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") - optimizer_train_fn = lambda: None - optimizer_eval_fn = lambda: None - elif args.fused_backward_pass: + if args.fused_backward_pass: # Pass per-component param_groups directly to preserve per-component LRs _, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups) optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) @@ -429,21 +335,19 @@ def train(args): train_dataset_group.set_max_train_steps(args.max_train_steps) # lr scheduler - if args.blockwise_fused_optimizers: - lr_schedulers = [train_util.get_scheduler_fix(args, opt, accelerator.num_processes) for opt in optimizers] - lr_scheduler = lr_schedulers[0] # avoid error in following code - else: - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # full fp16/bf16 training + dit_weight_dtype = weight_dtype if args.full_fp16: assert args.mixed_precision == "fp16", "full_fp16 requires mixed_precision='fp16'" accelerator.print("enable full fp16 training.") - dit.to(weight_dtype) elif args.full_bf16: assert args.mixed_precision == "bf16", "full_bf16 requires mixed_precision='bf16'" accelerator.print("enable full bf16 training.") - dit.to(weight_dtype) + else: + dit_weight_dtype = torch.float32 # Default to float32 + dit.to(dit_weight_dtype) # convert dit to target weight dtype # move text encoder to GPU if not cached if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None: @@ -485,6 +389,7 @@ def train(args): train_util.resume_from_local_or_hf_if_specified(accelerator, args) if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) @@ -504,53 +409,28 @@ def train(args): parameter.register_post_accumulate_grad_hook(create_grad_hook(param_group)) - elif args.blockwise_fused_optimizers: - # Prepare additional optimizers and lr schedulers - for i in range(1, len(optimizers)): - optimizers[i] = accelerator.prepare(optimizers[i]) - lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) - - # Counters for blockwise gradient hook - optimizer_hooked_count = {} - num_parameters_per_group = [0] * len(optimizers) - parameter_optimizer_map = {} - - for opt_idx, opt in enumerate(optimizers): - for param_group in opt.param_groups: - for parameter in param_group["params"]: - if parameter.requires_grad: - - def grad_hook(parameter: torch.Tensor): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - accelerator.clip_grad_norm_(parameter, args.max_grad_norm) - - i = parameter_optimizer_map[parameter] - optimizer_hooked_count[i] += 1 - if optimizer_hooked_count[i] == num_parameters_per_group[i]: - optimizers[i].step() - optimizers[i].zero_grad(set_to_none=True) - - parameter.register_post_accumulate_grad_hook(grad_hook) - parameter_optimizer_map[parameter] = opt_idx - num_parameters_per_group[opt_idx] += 1 - # Training loop num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - accelerator.print("running training") - accelerator.print(f" num examples: {train_dataset_group.num_train_images}") - accelerator.print(f" num batches per epoch: {len(train_dataloader)}") - accelerator.print(f" num epochs: {num_train_epochs}") - accelerator.print(f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}") - accelerator.print(f" total optimization steps: {args.max_train_steps}") + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 + noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: @@ -581,7 +461,6 @@ def train(args): global_step, dit, vae, - vae_scale, qwen3_text_encoder, tokenize_strategy, text_encoding_strategy, @@ -594,13 +473,11 @@ def train(args): # Show model info unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None if unwrapped_dit is not None: - logger.info( - f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}" - ) + logger.info(f"dit device: {unwrapped_dit.device}, dtype: {unwrapped_dit.dtype}") if qwen3_text_encoder is not None: - logger.info(f"qwen3 device: {next(qwen3_text_encoder.parameters()).device}") + logger.info(f"qwen3 device: {qwen3_text_encoder.device}") if vae is not None: - logger.info(f"vae device: {next(vae.parameters()).device}") + logger.info(f"vae device: {vae.device}") loss_recorder = train_util.LossRecorder() epoch = 0 @@ -614,19 +491,17 @@ def train(args): for step, batch in enumerate(train_dataloader): current_step.value = global_step - if args.blockwise_fused_optimizers: - optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step - with accelerator.accumulate(*training_models): # Get latents if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + latents = batch["latents"].to(accelerator.device, dtype=dit_weight_dtype) + if latents.ndim == 5: # Fallback for 5D latents (old cache) + latents = latents.squeeze(2) # (B, C, 1, H, W) -> (B, C, H, W) else: with torch.no_grad(): # images are already [-1, 1] from IMAGE_TRANSFORMS, add temporal dim images = batch["images"].to(accelerator.device, dtype=weight_dtype) - images = images.unsqueeze(2) # (B, C, 1, H, W) - latents = vae.encode(images, vae_scale).to(accelerator.device, dtype=weight_dtype) + latents = vae.encode_pixels_to_latents(images).to(accelerator.device, dtype=dit_weight_dtype) if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") @@ -636,21 +511,24 @@ def train(args): text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: # Cached outputs - text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) + caption_dropout_rates = text_encoder_outputs_list[-1] + text_encoder_outputs_list = text_encoder_outputs_list[:-1] + + # Apply caption dropout to cached outputs + text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs( + *text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates + ) prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list else: # Encode on-the-fly input_ids_list = batch["input_ids_list"] - qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = input_ids_list with torch.no_grad(): prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoding_strategy.encode_tokens( - tokenize_strategy, - [qwen3_text_encoder], - [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask], + tokenize_strategy, [qwen3_text_encoder], input_ids_list ) # Move to device - prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype) + prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit_weight_dtype) attn_mask = attn_mask.to(accelerator.device) t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long) t5_attn_mask = t5_attn_mask.to(accelerator.device) @@ -658,9 +536,11 @@ def train(args): # Noise and timesteps noise = torch.randn_like(latents) - noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps( - args, latents, noise, accelerator.device, weight_dtype + # Get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, dit_weight_dtype ) + timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32 # NaN checks if torch.any(torch.isnan(noisy_model_input)): @@ -672,12 +552,10 @@ def train(args): bs = latents.shape[0] h_latent = latents.shape[-2] w_latent = latents.shape[-1] - padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device) + padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=dit_weight_dtype, device=accelerator.device) # DiT forward (LLM adapter runs inside forward for DDP gradient sync) - if is_swapping_blocks: - accelerator.unwrap_model(dit).prepare_block_swap_before_forward() - + noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, (B, C, 1, H, W) with accelerator.autocast(): model_pred = dit( noisy_model_input, @@ -688,6 +566,7 @@ def train(args): t5_input_ids=t5_input_ids, t5_attn_mask=t5_attn_mask, ) + model_pred = model_pred.squeeze(2) # 5D to 4D, (B, C, H, W) # Compute loss (rectified flow: target = noise - latents) target = noise - latents @@ -702,7 +581,7 @@ def train(args): loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3, 4]) # (B, C, T, H, W) -> (B,) + loss = loss.mean([1, 2, 3]) # (B, C, H, W) -> (B,) if weighting is not None: loss = loss * weighting @@ -713,7 +592,7 @@ def train(args): accelerator.backward(loss) - if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if not args.fused_backward_pass: if accelerator.sync_gradients and args.max_grad_norm != 0.0: params_to_clip = [] for m in training_models: @@ -726,9 +605,6 @@ def train(args): else: # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook lr_scheduler.step() - if args.blockwise_fused_optimizers: - for i in range(1, len(optimizers)): - lr_schedulers[i].step() # Checks if the accelerator has performed an optimization step if accelerator.sync_gradients: @@ -743,7 +619,6 @@ def train(args): global_step, dit, vae, - vae_scale, qwen3_text_encoder, tokenize_strategy, text_encoding_strategy, @@ -812,7 +687,6 @@ def train(args): global_step, dit, vae, - vae_scale, qwen3_text_encoder, tokenize_strategy, text_encoding_strategy, @@ -859,11 +733,6 @@ def setup_parser() -> argparse.ArgumentParser: anima_train_utils.add_anima_training_arguments(parser) sai_model_spec.add_model_spec_arguments(parser) - # parser.add_argument( - # "--blockwise_fused_optimizers", - # action="store_true", - # help="enable blockwise optimizers for fused backward pass and optimizer step", - # ) parser.add_argument( "--cpu_offload_checkpointing", action="store_true", @@ -891,4 +760,7 @@ if __name__ == "__main__": train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) + if args.attn_mode == "sdpa": + args.attn_mode = "torch" # backward compatibility + train(args) diff --git a/library/attention.py b/library/attention.py index d3b8441e..4f6a5422 100644 --- a/library/attention.py +++ b/library/attention.py @@ -37,6 +37,14 @@ class AttentionParams: cu_seqlens: Optional[torch.Tensor] = None max_seqlen: Optional[int] = None + @property + def supports_fp32(self) -> bool: + return self.attn_mode not in ["flash"] + + @property + def requires_same_dtype(self) -> bool: + return self.attn_mode in ["xformers"] + @staticmethod def create_attention_params(attn_mode: Optional[str], split_attn: bool) -> "AttentionParams": return AttentionParams(attn_mode, split_attn) @@ -95,7 +103,7 @@ def attention( qkv_or_q: Query tensor [B, L, H, D]. or list of such tensors. k: Key tensor [B, L, H, D]. v: Value tensor [B, L, H, D]. - attn_param: Attention parameters including mask and sequence lengths. + attn_params: Attention parameters including mask and sequence lengths. drop_rate: Attention dropout rate. Returns: