fix: handle Chroma case where input_ids_list is None

Co-authored-by: aider (deepseek/deepseek-chat) <aider@aider.chat>
This commit is contained in:
johnr14
2025-09-20 09:05:29 -04:00
parent b55dc7ddc9
commit c9f76284aa

View File

@@ -413,8 +413,22 @@ class NetworkTrainer:
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
else:
# For debugging
logger.debug(f"text_encoder_outputs_list is None, batch keys: {list(batch.keys())}")
else:
# For debugging
print(f"text_encoder_outputs_list is None, batch keys: {list(batch.keys())}")
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
# For Chroma, text_encoder_conds might be set up differently
# Check if we need to encode text encoders
need_to_encode = len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder
# Also check if input_ids_list is None (for Chroma)
if "input_ids_list" in batch and batch["input_ids_list"] is None:
# If input_ids_list is None, we might already have the text encoder outputs cached
need_to_encode = False
if need_to_encode:
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
@@ -428,15 +442,26 @@ class NetworkTrainer:
)
else:
# Handle Chroma case where CLIP-L tokens might be None
input_ids = []
for ids in batch["input_ids_list"]:
if ids is not None: # Skip None values (CLIP-L tokens for Chroma)
input_ids.append(ids.to(accelerator.device))
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
input_ids,
)
# Check if input_ids_list exists and is not None
if "input_ids_list" in batch and batch["input_ids_list"] is not None:
input_ids = []
for ids in batch["input_ids_list"]:
if ids is not None: # Skip None values (CLIP-L tokens for Chroma)
input_ids.append(ids.to(accelerator.device))
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
input_ids,
)
else:
# For Chroma, we might have a different way to get the input ids
# Since input_ids_list is None, we need to handle this case
# Let's assume the text encoding strategy can handle this
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
[], # Pass empty list or handle differently
)
if args.full_fp16:
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]