mirror of https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
fix: use strategy instead of using tokenizers directly
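In outline, the change registers the tokenize/text-encoding strategies once (via `strategy_base.TokenizeStrategy.set_strategy` / `strategy_base.TextEncodingStrategy.set_strategy`) and has callers fetch them with `get_strategy()` instead of passing tokenizers around. Below is a minimal, self-contained sketch of that set/get pattern; the class and method names mirror the diff, but the bodies are illustrative stand-ins, not code from this repository.

```python
# Illustrative sketch only: a minimal stand-in for the set_strategy()/get_strategy()
# registry pattern this commit switches to. Names mirror the diff; bodies do not.
from typing import List, Optional


class TokenizeStrategy:
    _strategy: Optional["TokenizeStrategy"] = None

    @classmethod
    def set_strategy(cls, strategy: "TokenizeStrategy") -> None:
        # Registered once, e.g. in main(), instead of handing tokenizers to every function.
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["TokenizeStrategy"]:
        # Callers such as prepare_text_inputs() fetch the shared instance on demand.
        return cls._strategy

    def tokenize(self, prompt: str) -> List[int]:
        raise NotImplementedError


class AnimaTokenizeStrategy(TokenizeStrategy):
    """Stand-in for the real class, which wraps the Qwen3 and T5 tokenizers."""

    def __init__(self, qwen3_max_length: int = 512, t5_max_length: int = 512) -> None:
        self.qwen3_max_length = qwen3_max_length
        self.t5_max_length = t5_max_length

    def tokenize(self, prompt: str) -> List[int]:
        # Placeholder tokenization so the sketch runs end to end.
        return [ord(c) % 256 for c in prompt][: self.qwen3_max_length]


if __name__ == "__main__":
    TokenizeStrategy.set_strategy(AnimaTokenizeStrategy())  # set once at startup
    tokens = TokenizeStrategy.get_strategy().tokenize("a photo of a cat")  # fetch anywhere
    print(tokens[:5])
```

The diff applies the same registration to the encoding side, setting `strategy_anima.AnimaTextEncodingStrategy()` through `strategy_base.TextEncodingStrategy.set_strategy()` in `main()`.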
@@ -255,14 +255,6 @@ def load_dit_model(
lora_weights_list=lora_weights_list,
lora_multipliers=args.lora_multiplier,
)
# model = anima_utils.load_anima_dit(
# args.dit,
# dtype=loading_weight_dtype,
# device=loading_device,
# transformer_dtype=loading_weight_dtype,
# llm_adapter_path=None, # getattr(args, "llm_adapter_path", None),
# disable_mmap=False, # getattr(args, "disable_mmap_load_safetensors", False),
# )
if not args.fp8_scaled:
# simple cast to dit_weight_dtype
target_dtype = None # load as-is (dit_weight_dtype == dtype of the weights in state_dict)
@@ -311,9 +303,7 @@ def prepare_text_inputs(
conds_cache = {}
text_encoder_device = torch.device("cpu") if args.text_encoder_cpu else device
if shared_models is not None:
tokenizer = shared_models.get("tokenizer")
text_encoder = shared_models.get("text_encoder")
t5xxl_tokenizer = shared_models.get("t5xxl_tokenizer")

if "conds_cache" in shared_models: # Use shared cache if available
conds_cache = shared_models["conds_cache"]
@@ -321,32 +311,19 @@ def prepare_text_inputs(
# text_encoder is on device (batched inference) or CPU (interactive inference)
else: # Load if not in shared_models
text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder
# tokenizer, text_encoder = anima_text_encoder.load_qwen3(
# args.text_encoder, dtype=text_encoder_dtype, device=text_encoder_device, disable_mmap=True
# )
# t5xxl_tokenizer = anima_text_encoder.load_t5xxl_tokenizer()
text_encoder, _ = anima_utils.load_qwen3_text_encoder(
args.text_encoder, dtype=text_encoder_dtype, device=text_encoder_device
)
text_encoder.eval()
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
if tokenize_strategy is None:
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
qwen3_path=args.text_encoder,
t5_tokenizer_path=getattr(args, "t5_tokenizer_path", None),
qwen3_max_length=512, # args.qwen3_max_token_length,
t5_max_length=512, # args.t5_max_token_length,
)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
# Store references so load_target_model can reuse them
tokenizer = tokenize_strategy.qwen3_tokenizer
t5xxl_tokenizer = tokenize_strategy.t5_tokenizer

# Store original devices to move back later if they were shared. This does nothing if shared_models is None
text_encoder_original_device = text_encoder.device if text_encoder else None

# Ensure text_encoder is not None before proceeding
if not text_encoder or not tokenizer or not t5xxl_tokenizer:
raise ValueError("Text encoder or tokenizer is not loaded properly.")
if not text_encoder:
raise ValueError("Text encoder is not loaded properly.")

# Define a function to move models to device if needed
# This is to avoid moving models if not needed, especially in interactive mode
@@ -372,14 +349,14 @@ def prepare_text_inputs(
else:
move_models_to_device_if_needed()

encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()

with torch.no_grad():
# embed = anima_text_encoder.get_text_embeds(anima, tokenizer, text_encoder, t5xxl_tokenizer, prompt)
tokens = tokenize_strategy.tokenize(prompt)
embed = encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
crossattn_emb = anima.preprocess_text_embeds(
crossattn_emb = anima._preprocess_text_embeds(
source_hidden_states=embed[0].to(anima.device),
target_input_ids=embed[2].to(anima.device),
target_attention_mask=embed[3].to(anima.device),
@@ -402,7 +379,7 @@ def prepare_text_inputs(
# negative_embed = anima_text_encoder.get_text_embeds(anima, tokenizer, text_encoder, t5xxl_tokenizer, negative_prompt)
tokens = tokenize_strategy.tokenize(negative_prompt)
negative_embed = encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
crossattn_emb = anima.preprocess_text_embeds(
crossattn_emb = anima._preprocess_text_embeds(
source_hidden_states=negative_embed[0].to(anima.device),
target_input_ids=negative_embed[2].to(anima.device),
target_attention_mask=negative_embed[3].to(anima.device),
@@ -416,7 +393,7 @@ def prepare_text_inputs(

if not (shared_models and "text_encoder" in shared_models): # if loaded locally
# There is a bug text_encoder is not freed from GPU memory when text encoder is fp8
del tokenizer, text_encoder, t5xxl_tokenizer
del text_encoder
gc.collect() # This may force Text Encoder to be freed from GPU memory
else: # if shared, move back to original device (likely CPU)
if text_encoder:
@@ -719,13 +696,8 @@ def load_shared_models(args: argparse.Namespace) -> Dict:
shared_models = {}
# Load text encoders to CPU
text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder
tokenizer, text_encoder = anima_text_encoder.load_qwen3(
args.text_encoder, dtype=text_encoder_dtype, device="cpu", disable_mmap=True
)
t5xxl_tokenizer = anima_text_encoder.load_t5xxl_tokenizer()
shared_models["tokenizer"] = tokenizer
text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.text_encoder, dtype=text_encoder_dtype, device="cpu")
shared_models["text_encoder"] = text_encoder
shared_models["t5xxl_tokenizer"] = t5xxl_tokenizer
return shared_models
@@ -766,10 +738,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->

# Text Encoder loaded to CPU by load_text_encoder
text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder
tokenizer_batch, text_encoder_batch = anima_text_encoder.load_qwen3(
args.text_encoder, dtype=text_encoder_dtype, device="cpu", disable_mmap=True
)
t5xxl_tokenizer_batch = anima_text_encoder.load_t5xxl_tokenizer()
text_encoder_batch, _ = anima_utils.load_qwen3_text_encoder(args.text_encoder, dtype=text_encoder_dtype, device="cpu")

# Text Encoder to device for this phase
text_encoder_device = torch.device("cpu") if args.text_encoder_cpu else device
@@ -780,9 +749,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->

logger.info("Preprocessing text and LLM/TextEncoder encoding for all prompts...")
temp_shared_models_txt = {
"tokenizer": tokenizer_batch,
"text_encoder": text_encoder_batch, # on GPU if not text_encoder_cpu
"t5xxl_tokenizer": t5xxl_tokenizer_batch,
"conds_cache": conds_cache_batch,
}
@@ -795,7 +762,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
all_precomputed_text_data.append(text_data)

# Models should be removed from device after prepare_text_inputs
del tokenizer_batch, text_encoder_batch, t5xxl_tokenizer_batch, temp_shared_models_txt, conds_cache_batch
del text_encoder_batch, temp_shared_models_txt, conds_cache_batch
gc.collect() # Force cleanup of Text Encoder from GPU memory
clean_memory_on_device(device)
@@ -1001,41 +968,50 @@ def main():
vae.eval()
save_output(args, vae, latent, device, original_base_names[i])

elif args.from_file:
# Batch mode from file

# Read prompts from file
with open(args.from_file, "r", encoding="utf-8") as f:
prompt_lines = f.readlines()

# Process prompts
prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
process_batch_prompts(prompts_data, args)

elif args.interactive:
# Interactive mode
process_interactive(args)

else:
# Single prompt mode (original behavior)
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
qwen3_path=args.text_encoder, t5_tokenizer_path=None, qwen3_max_length=512, t5_max_length=512
)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)

# Generate latent
gen_settings = get_generation_settings(args)
encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
strategy_base.TextEncodingStrategy.set_strategy(encoding_strategy)

# For single mode, precomputed data is None, shared_models is None.
# generate will load all necessary models (Text Encoders, DiT).
latent = generate(args, gen_settings)
# print(f"Generated latent shape: {latent.shape}")
# if args.save_merged_model:
# return
if args.from_file:
# Batch mode from file

clean_memory_on_device(device)
# Read prompts from file
with open(args.from_file, "r", encoding="utf-8") as f:
prompt_lines = f.readlines()

# Save latent and video
vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
# Process prompts
prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
process_batch_prompts(prompts_data, args)

vae.eval()
save_output(args, vae, latent, device)
elif args.interactive:
# Interactive mode
process_interactive(args)

else:
# Single prompt mode (original behavior)

# Generate latent
gen_settings = get_generation_settings(args)

# For single mode, precomputed data is None, shared_models is None.
# generate will load all necessary models (Text Encoders, DiT).
latent = generate(args, gen_settings)
# print(f"Generated latent shape: {latent.shape}")
# if args.save_merged_model:
# return

clean_memory_on_device(device)

# Save latent and video
vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)

vae.eval()
save_output(args, vae, latent, device)

logger.info("Done!")