diff --git a/anima_minimal_inference.py b/anima_minimal_inference.py
index 67af2876..8ab86ca2 100644
--- a/anima_minimal_inference.py
+++ b/anima_minimal_inference.py
@@ -255,14 +255,6 @@ def load_dit_model(
         lora_weights_list=lora_weights_list,
         lora_multipliers=args.lora_multiplier,
     )
-    # model = anima_utils.load_anima_dit(
-    #     args.dit,
-    #     dtype=loading_weight_dtype,
-    #     device=loading_device,
-    #     transformer_dtype=loading_weight_dtype,
-    #     llm_adapter_path=None,  # getattr(args, "llm_adapter_path", None),
-    #     disable_mmap=False,  # getattr(args, "disable_mmap_load_safetensors", False),
-    # )
     if not args.fp8_scaled:
         # simple cast to dit_weight_dtype
         target_dtype = None  # load as-is (dit_weight_dtype == dtype of the weights in state_dict)
@@ -311,9 +303,7 @@ def prepare_text_inputs(
     conds_cache = {}
     text_encoder_device = torch.device("cpu") if args.text_encoder_cpu else device
     if shared_models is not None:
-        tokenizer = shared_models.get("tokenizer")
         text_encoder = shared_models.get("text_encoder")
-        t5xxl_tokenizer = shared_models.get("t5xxl_tokenizer")
         if "conds_cache" in shared_models:  # Use shared cache if available
             conds_cache = shared_models["conds_cache"]

@@ -321,32 +311,19 @@ def prepare_text_inputs(
         # text_encoder is on device (batched inference) or CPU (interactive inference)
     else:  # Load if not in shared_models
         text_encoder_dtype = torch.bfloat16  # Default dtype for Text Encoder
-        # tokenizer, text_encoder = anima_text_encoder.load_qwen3(
-        #     args.text_encoder, dtype=text_encoder_dtype, device=text_encoder_device, disable_mmap=True
-        # )
-        # t5xxl_tokenizer = anima_text_encoder.load_t5xxl_tokenizer()
         text_encoder, _ = anima_utils.load_qwen3_text_encoder(
             args.text_encoder, dtype=text_encoder_dtype, device=text_encoder_device
         )
         text_encoder.eval()
         tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
-        if tokenize_strategy is None:
-            tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
-                qwen3_path=args.text_encoder,
-                t5_tokenizer_path=getattr(args, "t5_tokenizer_path", None),
-                qwen3_max_length=512,  # args.qwen3_max_token_length,
-                t5_max_length=512,  # args.t5_max_token_length,
-            )
-            strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)

         # Store references so load_target_model can reuse them
-        tokenizer = tokenize_strategy.qwen3_tokenizer
-        t5xxl_tokenizer = tokenize_strategy.t5_tokenizer
+    # Store original devices to move back later if they were shared. This does nothing if shared_models is None
     text_encoder_original_device = text_encoder.device if text_encoder else None

     # Ensure text_encoder is not None before proceeding
-    if not text_encoder or not tokenizer or not t5xxl_tokenizer:
-        raise ValueError("Text encoder or tokenizer is not loaded properly.")
+    if not text_encoder:
+        raise ValueError("Text encoder is not loaded properly.")

     # Define a function to move models to device if needed
     # This is to avoid moving models if not needed, especially in interactive mode
@@ -372,14 +349,14 @@ def prepare_text_inputs(
     else:
         move_models_to_device_if_needed()

-    encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
     tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
+    encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()

     with torch.no_grad():
         # embed = anima_text_encoder.get_text_embeds(anima, tokenizer, text_encoder, t5xxl_tokenizer, prompt)
         tokens = tokenize_strategy.tokenize(prompt)
         embed = encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
-        crossattn_emb = anima.preprocess_text_embeds(
+        crossattn_emb = anima._preprocess_text_embeds(
             source_hidden_states=embed[0].to(anima.device),
             target_input_ids=embed[2].to(anima.device),
             target_attention_mask=embed[3].to(anima.device),
@@ -402,7 +379,7 @@ def prepare_text_inputs(
         # negative_embed = anima_text_encoder.get_text_embeds(anima, tokenizer, text_encoder, t5xxl_tokenizer, negative_prompt)
         tokens = tokenize_strategy.tokenize(negative_prompt)
         negative_embed = encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
-        crossattn_emb = anima.preprocess_text_embeds(
+        crossattn_emb = anima._preprocess_text_embeds(
             source_hidden_states=negative_embed[0].to(anima.device),
             target_input_ids=negative_embed[2].to(anima.device),
             target_attention_mask=negative_embed[3].to(anima.device),
@@ -416,7 +393,7 @@ def prepare_text_inputs(

     if not (shared_models and "text_encoder" in shared_models):  # if loaded locally
         # There is a bug text_encoder is not freed from GPU memory when text encoder is fp8
-        del tokenizer, text_encoder, t5xxl_tokenizer
+        del text_encoder
         gc.collect()  # This may force Text Encoder to be freed from GPU memory
     else:  # if shared, move back to original device (likely CPU)
         if text_encoder:
@@ -719,13 +696,8 @@ def load_shared_models(args: argparse.Namespace) -> Dict:
     shared_models = {}
     # Load text encoders to CPU
     text_encoder_dtype = torch.bfloat16  # Default dtype for Text Encoder
-    tokenizer, text_encoder = anima_text_encoder.load_qwen3(
-        args.text_encoder, dtype=text_encoder_dtype, device="cpu", disable_mmap=True
-    )
-    t5xxl_tokenizer = anima_text_encoder.load_t5xxl_tokenizer()
-    shared_models["tokenizer"] = tokenizer
+    text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.text_encoder, dtype=text_encoder_dtype, device="cpu")
     shared_models["text_encoder"] = text_encoder
-    shared_models["t5xxl_tokenizer"] = t5xxl_tokenizer

     return shared_models

@@ -766,10 +738,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->

     # Text Encoder loaded to CPU by load_text_encoder
     text_encoder_dtype = torch.bfloat16  # Default dtype for Text Encoder
-    tokenizer_batch, text_encoder_batch = anima_text_encoder.load_qwen3(
-        args.text_encoder, dtype=text_encoder_dtype, device="cpu", disable_mmap=True
-    )
-    t5xxl_tokenizer_batch = anima_text_encoder.load_t5xxl_tokenizer()
+    text_encoder_batch, _ = anima_utils.load_qwen3_text_encoder(args.text_encoder, dtype=text_encoder_dtype, device="cpu")

     # Text Encoder to device for this phase
     text_encoder_device = torch.device("cpu") if args.text_encoder_cpu else device
@@ -780,9 +749,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->

     logger.info("Preprocessing text and LLM/TextEncoder encoding for all prompts...")
     temp_shared_models_txt = {
-        "tokenizer": tokenizer_batch,
         "text_encoder": text_encoder_batch,  # on GPU if not text_encoder_cpu
-        "t5xxl_tokenizer": t5xxl_tokenizer_batch,
         "conds_cache": conds_cache_batch,
     }

@@ -795,7 +762,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
         all_precomputed_text_data.append(text_data)

     # Models should be removed from device after prepare_text_inputs
-    del tokenizer_batch, text_encoder_batch, t5xxl_tokenizer_batch, temp_shared_models_txt, conds_cache_batch
+    del text_encoder_batch, temp_shared_models_txt, conds_cache_batch
     gc.collect()  # Force cleanup of Text Encoder from GPU memory
     clean_memory_on_device(device)

@@ -1001,41 +968,50 @@ def main():
         vae.eval()
         save_output(args, vae, latent, device, original_base_names[i])

-    elif args.from_file:
-        # Batch mode from file
-
-        # Read prompts from file
-        with open(args.from_file, "r", encoding="utf-8") as f:
-            prompt_lines = f.readlines()
-
-        # Process prompts
-        prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
-        process_batch_prompts(prompts_data, args)
-
-    elif args.interactive:
-        # Interactive mode
-        process_interactive(args)
-    else:
-        # Single prompt mode (original behavior)
+    tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
+        qwen3_path=args.text_encoder, t5_tokenizer_path=None, qwen3_max_length=512, t5_max_length=512
+    )
+    strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)

-        # Generate latent
-        gen_settings = get_generation_settings(args)
+    encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
+    strategy_base.TextEncodingStrategy.set_strategy(encoding_strategy)

-        # For single mode, precomputed data is None, shared_models is None.
-        # generate will load all necessary models (Text Encoders, DiT).
-        latent = generate(args, gen_settings)
-        # print(f"Generated latent shape: {latent.shape}")
-        # if args.save_merged_model:
-        #     return
+    if args.from_file:
+        # Batch mode from file

-        clean_memory_on_device(device)
+        # Read prompts from file
+        with open(args.from_file, "r", encoding="utf-8") as f:
+            prompt_lines = f.readlines()

-        # Save latent and video
-        vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
+        # Process prompts
+        prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
+        process_batch_prompts(prompts_data, args)

-        vae.eval()
-        save_output(args, vae, latent, device)
+    elif args.interactive:
+        # Interactive mode
+        process_interactive(args)
+
+    else:
+        # Single prompt mode (original behavior)
+
+        # Generate latent
+        gen_settings = get_generation_settings(args)
+
+        # For single mode, precomputed data is None, shared_models is None.
+        # generate will load all necessary models (Text Encoders, DiT).
+        latent = generate(args, gen_settings)
+        # print(f"Generated latent shape: {latent.shape}")
+        # if args.save_merged_model:
+        #     return
+
+        clean_memory_on_device(device)
+
+        # Save latent and video
+        vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
+
+        vae.eval()
+        save_output(args, vae, latent, device)

     logger.info("Done!")
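
Usage sketch (supplementary, not part of the patch): a minimal example of how the strategy registration that main() now performs once is consumed later by prepare_text_inputs(). The class and function names come from this diff; the module import names, file paths, and the 512-token limits are placeholders/assumptions for illustration only.

# Hypothetical standalone sketch; import names and paths are placeholders.
import torch

import anima_utils       # assumed to provide load_qwen3_text_encoder (per this diff)
import strategy_anima    # assumed to provide AnimaTokenizeStrategy / AnimaTextEncodingStrategy
import strategy_base     # assumed to provide the TokenizeStrategy / TextEncodingStrategy registries

# Register the strategies once, as main() does after this change.
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
    qwen3_path="path/to/qwen3", t5_tokenizer_path=None, qwen3_max_length=512, t5_max_length=512
)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
strategy_base.TextEncodingStrategy.set_strategy(strategy_anima.AnimaTextEncodingStrategy())

# Later callers (e.g. prepare_text_inputs) fetch the registered strategies instead of
# constructing their own tokenizer objects.
text_encoder, _ = anima_utils.load_qwen3_text_encoder("path/to/text_encoder", dtype=torch.bfloat16, device="cpu")
text_encoder.eval()

tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
with torch.no_grad():
    tokens = tokenize_strategy.tokenize("a cat walking on the beach")
    embed = encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)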