mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 08:52:45 +00:00
Merge 10ff863b8a into 1dae34b0af
This commit is contained in:
@@ -400,8 +400,19 @@ class NetworkTrainer:
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
||||
else:
|
||||
# For debugging
|
||||
logger.debug(f"text_encoder_outputs_list is None, batch keys: {list(batch.keys())}")
|
||||
|
||||
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
|
||||
# For Chroma, text_encoder_conds might be set up differently
|
||||
# Check if we need to encode text encoders
|
||||
need_to_encode = len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder
|
||||
# Also check if input_ids_list is None (for Chroma)
|
||||
if "input_ids_list" in batch and batch["input_ids_list"] is None:
|
||||
# If input_ids_list is None, we might already have the text encoder outputs cached
|
||||
need_to_encode = False
|
||||
|
||||
if need_to_encode:
|
||||
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
|
||||
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
@@ -414,12 +425,27 @@ class NetworkTrainer:
|
||||
weights_list,
|
||||
)
|
||||
else:
|
||||
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids,
|
||||
)
|
||||
# Handle Chroma case where CLIP-L tokens might be None
|
||||
# Check if input_ids_list exists and is not None
|
||||
if "input_ids_list" in batch and batch["input_ids_list"] is not None:
|
||||
input_ids = []
|
||||
for ids in batch["input_ids_list"]:
|
||||
if ids is not None: # Skip None values (CLIP-L tokens for Chroma)
|
||||
input_ids.append(ids.to(accelerator.device))
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids,
|
||||
)
|
||||
else:
|
||||
# For Chroma, we might have a different way to get the input ids
|
||||
# Since input_ids_list is None, we need to handle this case
|
||||
# Let's assume the text encoding strategy can handle this
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
[], # Pass empty list or handle differently
|
||||
)
|
||||
if args.full_fp16:
|
||||
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
|
||||
|
||||
@@ -472,6 +498,7 @@ class NetworkTrainer:
|
||||
return True # default for other than HunyuanImage
|
||||
|
||||
def train(self, args):
|
||||
self.args = args # store args for later use
|
||||
session_id = random.randint(0, 2**32)
|
||||
training_started_at = time.time()
|
||||
train_util.verify_training_args(args)
|
||||
|
||||
Reference in New Issue
Block a user