fix: use strategy instead of using tokenizers directly

kohya-ss
2026-02-09 12:42:16 +09:00
parent 35161b044c
commit a1e3d02259


@@ -255,14 +255,6 @@ def load_dit_model(
lora_weights_list=lora_weights_list,
lora_multipliers=args.lora_multiplier,
)
# model = anima_utils.load_anima_dit(
# args.dit,
# dtype=loading_weight_dtype,
# device=loading_device,
# transformer_dtype=loading_weight_dtype,
# llm_adapter_path=None, # getattr(args, "llm_adapter_path", None),
# disable_mmap=False, # getattr(args, "disable_mmap_load_safetensors", False),
# )
if not args.fp8_scaled:
# simple cast to dit_weight_dtype
target_dtype = None # load as-is (dit_weight_dtype == dtype of the weights in state_dict)
@@ -311,9 +303,7 @@ def prepare_text_inputs(
conds_cache = {}
text_encoder_device = torch.device("cpu") if args.text_encoder_cpu else device
if shared_models is not None:
tokenizer = shared_models.get("tokenizer")
text_encoder = shared_models.get("text_encoder")
t5xxl_tokenizer = shared_models.get("t5xxl_tokenizer")
if "conds_cache" in shared_models: # Use shared cache if available
conds_cache = shared_models["conds_cache"]
@@ -321,32 +311,19 @@ def prepare_text_inputs(
# text_encoder is on device (batched inference) or CPU (interactive inference)
else: # Load if not in shared_models
text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder
# tokenizer, text_encoder = anima_text_encoder.load_qwen3(
# args.text_encoder, dtype=text_encoder_dtype, device=text_encoder_device, disable_mmap=True
# )
# t5xxl_tokenizer = anima_text_encoder.load_t5xxl_tokenizer()
text_encoder, _ = anima_utils.load_qwen3_text_encoder(
args.text_encoder, dtype=text_encoder_dtype, device=text_encoder_device
)
text_encoder.eval()
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
if tokenize_strategy is None:
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
qwen3_path=args.text_encoder,
t5_tokenizer_path=getattr(args, "t5_tokenizer_path", None),
qwen3_max_length=512, # args.qwen3_max_token_length,
t5_max_length=512, # args.t5_max_token_length,
)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
# Store references so load_target_model can reuse them
tokenizer = tokenize_strategy.qwen3_tokenizer
t5xxl_tokenizer = tokenize_strategy.t5_tokenizer
# Store original devices to move back later if they were shared. This does nothing if shared_models is None
text_encoder_original_device = text_encoder.device if text_encoder else None
# Ensure text_encoder is not None before proceeding
if not text_encoder or not tokenizer or not t5xxl_tokenizer:
raise ValueError("Text encoder or tokenizer is not loaded properly.")
if not text_encoder:
raise ValueError("Text encoder is not loaded properly.")
# Define a function to move models to device if needed
# This is to avoid moving models if not needed, especially in interactive mode
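For reference, the tokenizer handling above follows a get-or-create pattern against the process-wide strategy registry. A minimal sketch, assuming only the names that appear in this diff (the wrapper function and any imports are illustrative):

    def get_or_create_tokenize_strategy(args):
        # Reuse an already-registered strategy if one exists
        tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
        if tokenize_strategy is None:
            # Otherwise build one from the Qwen3 / T5 tokenizer settings and register it globally
            tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
                qwen3_path=args.text_encoder,
                t5_tokenizer_path=getattr(args, "t5_tokenizer_path", None),
                qwen3_max_length=512,
                t5_max_length=512,
            )
            strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
        return tokenize_strategy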
@@ -372,14 +349,14 @@ def prepare_text_inputs(
else:
move_models_to_device_if_needed()
encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
with torch.no_grad():
# embed = anima_text_encoder.get_text_embeds(anima, tokenizer, text_encoder, t5xxl_tokenizer, prompt)
tokens = tokenize_strategy.tokenize(prompt)
embed = encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
crossattn_emb = anima.preprocess_text_embeds(
crossattn_emb = anima._preprocess_text_embeds(
source_hidden_states=embed[0].to(anima.device),
target_input_ids=embed[2].to(anima.device),
target_attention_mask=embed[3].to(anima.device),
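With both strategies registered, encoding a prompt no longer touches the tokenizers directly. A rough sketch of the flow used above (the tuple indices mirror the calls in the diff; the wrapper function is illustrative):

    import torch

    def encode_prompt(prompt, text_encoder, anima):
        tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
        encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
        with torch.no_grad():
            tokens = tokenize_strategy.tokenize(prompt)
            embed = encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
        # embed[0]: source hidden states; embed[2]/embed[3]: target input ids and attention mask
        return anima._preprocess_text_embeds(
            source_hidden_states=embed[0].to(anima.device),
            target_input_ids=embed[2].to(anima.device),
            target_attention_mask=embed[3].to(anima.device),
        )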
@@ -402,7 +379,7 @@ def prepare_text_inputs(
# negative_embed = anima_text_encoder.get_text_embeds(anima, tokenizer, text_encoder, t5xxl_tokenizer, negative_prompt)
tokens = tokenize_strategy.tokenize(negative_prompt)
negative_embed = encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
crossattn_emb = anima.preprocess_text_embeds(
crossattn_emb = anima._preprocess_text_embeds(
source_hidden_states=negative_embed[0].to(anima.device),
target_input_ids=negative_embed[2].to(anima.device),
target_attention_mask=negative_embed[3].to(anima.device),
@@ -416,7 +393,7 @@ def prepare_text_inputs(
if not (shared_models and "text_encoder" in shared_models): # if loaded locally
# There is a bug where text_encoder is not freed from GPU memory when the text encoder is fp8
del tokenizer, text_encoder, t5xxl_tokenizer
del text_encoder
gc.collect() # This may force Text Encoder to be freed from GPU memory
else: # if shared, move back to original device (likely CPU)
if text_encoder:
@@ -719,13 +696,8 @@ def load_shared_models(args: argparse.Namespace) -> Dict:
shared_models = {}
# Load text encoders to CPU
text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder
tokenizer, text_encoder = anima_text_encoder.load_qwen3(
args.text_encoder, dtype=text_encoder_dtype, device="cpu", disable_mmap=True
)
t5xxl_tokenizer = anima_text_encoder.load_t5xxl_tokenizer()
shared_models["tokenizer"] = tokenizer
text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.text_encoder, dtype=text_encoder_dtype, device="cpu")
shared_models["text_encoder"] = text_encoder
shared_models["t5xxl_tokenizer"] = t5xxl_tokenizer
return shared_models
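With the tokenizers owned by the registered strategy, the shared-models dictionary only needs the text encoder (plus an optional conditioning cache). On the consumer side this looks roughly like the following (a sketch, not the actual call sites):

    shared_models = load_shared_models(args)  # now just {"text_encoder": ...}
    shared_models["conds_cache"] = {}  # optional cache reused across prompts
    text_encoder = shared_models.get("text_encoder")
    # Tokenizers come from the strategy registry instead of shared_models
    tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()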
@@ -766,10 +738,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
# Text Encoder loaded to CPU by load_text_encoder
text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder
tokenizer_batch, text_encoder_batch = anima_text_encoder.load_qwen3(
args.text_encoder, dtype=text_encoder_dtype, device="cpu", disable_mmap=True
)
t5xxl_tokenizer_batch = anima_text_encoder.load_t5xxl_tokenizer()
text_encoder_batch, _ = anima_utils.load_qwen3_text_encoder(args.text_encoder, dtype=text_encoder_dtype, device="cpu")
# Text Encoder to device for this phase
text_encoder_device = torch.device("cpu") if args.text_encoder_cpu else device
@@ -780,9 +749,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
logger.info("Preprocessing text and LLM/TextEncoder encoding for all prompts...")
temp_shared_models_txt = {
"tokenizer": tokenizer_batch,
"text_encoder": text_encoder_batch, # on GPU if not text_encoder_cpu
"t5xxl_tokenizer": t5xxl_tokenizer_batch,
"conds_cache": conds_cache_batch,
}
@@ -795,7 +762,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
all_precomputed_text_data.append(text_data)
# Models should be removed from device after prepare_text_inputs
del tokenizer_batch, text_encoder_batch, t5xxl_tokenizer_batch, temp_shared_models_txt, conds_cache_batch
del text_encoder_batch, temp_shared_models_txt, conds_cache_batch
gc.collect() # Force cleanup of Text Encoder from GPU memory
clean_memory_on_device(device)
@@ -1001,41 +968,50 @@ def main():
vae.eval()
save_output(args, vae, latent, device, original_base_names[i])
elif args.from_file:
# Batch mode from file
# Read prompts from file
with open(args.from_file, "r", encoding="utf-8") as f:
prompt_lines = f.readlines()
# Process prompts
prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
process_batch_prompts(prompts_data, args)
elif args.interactive:
# Interactive mode
process_interactive(args)
else:
# Single prompt mode (original behavior)
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
qwen3_path=args.text_encoder, t5_tokenizer_path=None, qwen3_max_length=512, t5_max_length=512
)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
# Generate latent
gen_settings = get_generation_settings(args)
encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
strategy_base.TextEncodingStrategy.set_strategy(encoding_strategy)
# For single mode, precomputed data is None, shared_models is None.
# generate will load all necessary models (Text Encoders, DiT).
latent = generate(args, gen_settings)
# print(f"Generated latent shape: {latent.shape}")
# if args.save_merged_model:
# return
clean_memory_on_device(device)
# Save latent and video
vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
vae.eval()
save_output(args, vae, latent, device)
if args.from_file:
# Batch mode from file
# Read prompts from file
with open(args.from_file, "r", encoding="utf-8") as f:
prompt_lines = f.readlines()
# Process prompts
prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
process_batch_prompts(prompts_data, args)
elif args.interactive:
# Interactive mode
process_interactive(args)
else:
# Single prompt mode (original behavior)
# Generate latent
gen_settings = get_generation_settings(args)
# For single mode, precomputed data is None, shared_models is None.
# generate will load all necessary models (Text Encoders, DiT).
latent = generate(args, gen_settings)
# print(f"Generated latent shape: {latent.shape}")
# if args.save_merged_model:
# return
clean_memory_on_device(device)
# Save latent and video
vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
vae.eval()
save_output(args, vae, latent, device)
logger.info("Done!")
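Taken together, the single-prompt path boils down to: register the strategies, generate the latent, then decode and save. A condensed sketch built only from calls that appear in this diff (the wrapper function and exact placement are illustrative):

    def run_single_prompt(args, device):
        # Register strategies so downstream code can fetch them via get_strategy()
        strategy_base.TokenizeStrategy.set_strategy(
            strategy_anima.AnimaTokenizeStrategy(
                qwen3_path=args.text_encoder, t5_tokenizer_path=None, qwen3_max_length=512, t5_max_length=512
            )
        )
        strategy_base.TextEncodingStrategy.set_strategy(strategy_anima.AnimaTextEncodingStrategy())

        # Generate the latent; generate() loads the Text Encoder and DiT as needed
        gen_settings = get_generation_settings(args)
        latent = generate(args, gen_settings)
        clean_memory_on_device(device)

        # Decode and save with the VAE
        vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
        vae.eval()
        save_output(args, vae, latent, device)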