fix max_token_length not working for sdxl

Kohya S
2023-06-29 13:02:19 +09:00
parent 8521ab7990
commit d395bc0647
5 changed files with 39 additions and 33 deletions


@@ -382,10 +382,10 @@ def train(args):
                 encoder_hidden_states2 = []
                 pool2 = []
                 for input_id1, input_id2 in zip(input_ids1, input_ids2):
-                    input_id1 = input_id1.squeeze(0)
-                    input_id2 = input_id2.squeeze(0)
-                    encoder_hidden_states1.append(text_encoder1_cache[tuple(input_id1.tolist())])
-                    hidden_states2, p2 = text_encoder2_cache[tuple(input_id2.tolist())]
+                    input_id1_cache_key = tuple(input_id1.squeeze(0).flatten().tolist())
+                    input_id2_cache_key = tuple(input_id2.squeeze(0).flatten().tolist())
+                    encoder_hidden_states1.append(text_encoder1_cache[input_id1_cache_key])
+                    hidden_states2, p2 = text_encoder2_cache[input_id2_cache_key]
                     encoder_hidden_states2.append(hidden_states2)
                     pool2.append(p2)
                 encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)
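
Why flattening fixes the key: when max_token_length is set, a caption is tokenized into several 77-token chunks, so input_id1 stays 2-D even after squeeze(0); tuple(tensor.tolist()) then yields a tuple of lists, which is unhashable and fails as a dictionary key. Below is a minimal sketch of the failure and the fix; the (1, 3, 77) shape is an illustrative assumption, not taken from the commit.

    import torch

    # Hypothetical multi-chunk input_ids: e.g. three 77-token chunks for one
    # prompt when max_token_length is set (shape assumed for this demo).
    input_id = torch.zeros((1, 3, 77), dtype=torch.long)

    # Old key construction: squeeze(0) leaves a 2-D tensor, so .tolist()
    # returns nested lists and tuple(...) is a tuple of lists -- unhashable.
    try:
        cache = {tuple(input_id.squeeze(0).tolist()): "cached output"}
    except TypeError as e:
        print(e)  # unhashable type: 'list'

    # New key construction: flatten() guarantees a 1-D tensor, so the key is
    # a tuple of plain ints no matter how many chunks the tokenizer produced.
    cache = {tuple(input_id.squeeze(0).flatten().tolist()): "cached output"}

With a single 77-token chunk the old and new keys coincide, which is why the bug only surfaced once max_token_length introduced the extra chunk dimension.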