feat: add Chroma model support with CLIP-L token handling

Co-authored-by: aider (deepseek/deepseek-chat) <aider@aider.chat>
This commit is contained in:
johnr14
2025-09-20 08:24:12 -04:00
parent f5d44fd487
commit af63e5422d
3 changed files with 43 additions and 24 deletions

View File

@@ -162,22 +162,28 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def get_tokenize_strategy(self, args):
# This method is called before `assert_extra_args`, so we cannot use `self.is_schnell` here.
# Instead, we analyze the checkpoint state to determine if it is schnell.
if args.model_type != "chroma":
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
else:
if args.model_type == "chroma":
is_schnell = False
self.is_schnell = is_schnell
if args.t5xxl_max_token_length is None:
if self.is_schnell:
t5xxl_max_token_length = 256
else:
t5xxl_max_token_length = 512
self.is_schnell = is_schnell
t5xxl_max_token_length = args.t5xxl_max_token_length or 512
logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
# Chroma doesn't use CLIP-L
return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir, use_clip_l=False)
else:
t5xxl_max_token_length = args.t5xxl_max_token_length
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
self.is_schnell = is_schnell
logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
if args.t5xxl_max_token_length is None:
if self.is_schnell:
t5xxl_max_token_length = 256
else:
t5xxl_max_token_length = 512
else:
t5xxl_max_token_length = args.t5xxl_max_token_length
logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
# FLUX models use both CLIP-L and T5
return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir, use_clip_l=True)
def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]

View File

@@ -21,19 +21,25 @@ T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
class FluxTokenizeStrategy(TokenizeStrategy):
def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None, use_clip_l: bool = True) -> None:
self.t5xxl_max_length = t5xxl_max_length
self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.use_clip_l = use_clip_l
if self.use_clip_l:
self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
if self.use_clip_l:
l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
l_tokens = l_tokens["input_ids"]
else:
# For Chroma, return None for CLIP-L tokens
l_tokens = None
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
t5_attn_mask = t5_tokens["attention_mask"]
l_tokens = l_tokens["input_ids"]
t5_tokens = t5_tokens["input_ids"]
return [l_tokens, t5_tokens, t5_attn_mask]
@@ -63,24 +69,27 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
l_tokens, t5_tokens = tokens[:2]
t5_attn_mask = tokens[2] if len(tokens) > 2 else None
# clip_l is None when using T5 only
# Handle Chroma case where CLIP-L is not used
if clip_l is not None and l_tokens is not None:
l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"]
else:
l_pooled = None
# For Chroma, create a dummy tensor with the right shape
if t5_tokens is not None:
batch_size = t5_tokens.shape[0]
l_pooled = torch.zeros(batch_size, 768, device=t5_tokens.device, dtype=torch.float32)
else:
l_pooled = None
# t5xxl is None when using CLIP only
if t5xxl is not None and t5_tokens is not None:
# t5_out is [b, max length, 4096]
attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device)
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True)
# if zero_pad_t5_output:
# t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device)
else:
t5_out = None
txt_ids = None
t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one
t5_attn_mask = None
return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer

View File

@@ -427,7 +427,11 @@ class NetworkTrainer:
weights_list,
)
else:
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
# Handle Chroma case where CLIP-L tokens might be None
input_ids = []
for ids in batch["input_ids_list"]:
if ids is not None: # Skip None values (CLIP-L tokens for Chroma)
input_ids.append(ids.to(accelerator.device))
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),