refactor: improve tokenizer access robustness with fallback

Co-authored-by: aider (deepseek/deepseek-chat) <aider@aider.chat>
This commit is contained in:
johnr14
2025-09-20 08:55:46 -04:00
parent 3040b31d1d
commit b55dc7ddc9

View File

@@ -187,7 +187,26 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir, use_clip_l=True)
def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
    """Return the CLIP-L and T5-XXL tokenizers as ``[clip_l, t5xxl]``.

    The tokenizers are looked up on ``tokenize_strategy`` under each known
    pair of attribute names, in order. If no complete pair is present, a
    warning is logged and fresh tokenizers are loaded with
    ``from_pretrained``.

    Args:
        tokenize_strategy: Object expected to expose the tokenizers either as
            ``clip_l`` / ``t5xxl`` or as ``clip_l_tokenizer`` /
            ``t5xxl_tokenizer``.

    Returns:
        A two-element list: the CLIP-L tokenizer followed by the T5-XXL
        tokenizer.
    """
    # Accepted attribute spellings, in priority order. A pair is only used
    # when BOTH attributes exist, so the result never mixes naming schemes.
    attribute_pairs = (
        ("clip_l", "t5xxl"),
        ("clip_l_tokenizer", "t5xxl_tokenizer"),
    )
    for clip_attr, t5_attr in attribute_pairs:
        if hasattr(tokenize_strategy, clip_attr) and hasattr(tokenize_strategy, t5_attr):
            return [getattr(tokenize_strategy, clip_attr), getattr(tokenize_strategy, t5_attr)]

    # Last resort: build new tokenizers instead of failing.
    logger.warning("Tokenizers not found in tokenize strategy, creating new ones")
    # Imported lazily so the dependency is only touched on the fallback path.
    from transformers import CLIPTokenizer, T5TokenizerFast

    # ``self.args`` may be absent (or None); either way fall back to the
    # library's default cache location. Computed once and shared by both loads.
    cache_dir = getattr(getattr(self, "args", None), "tokenizer_cache_dir", None)

    clip_l_tokenizer = CLIPTokenizer.from_pretrained(
        "openai/clip-vit-large-patch14",
        cache_dir=cache_dir,
    )
    t5xxl_tokenizer = T5TokenizerFast.from_pretrained(
        "google/t5-v1_1-xxl",
        cache_dir=cache_dir,
    )
    return [clip_l_tokenizer, t5xxl_tokenizer]
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)