mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix max_token_length not works for sdxl
This commit is contained in:
@@ -382,10 +382,10 @@ def train(args):
|
||||
encoder_hidden_states2 = []
|
||||
pool2 = []
|
||||
for input_id1, input_id2 in zip(input_ids1, input_ids2):
|
||||
input_id1 = input_id1.squeeze(0)
|
||||
input_id2 = input_id2.squeeze(0)
|
||||
encoder_hidden_states1.append(text_encoder1_cache[tuple(input_id1.tolist())])
|
||||
hidden_states2, p2 = text_encoder2_cache[tuple(input_id2.tolist())]
|
||||
input_id1_cache_key = tuple(input_id1.squeeze(0).flatten().tolist())
|
||||
input_id2_cache_key = tuple(input_id2.squeeze(0).flatten().tolist())
|
||||
encoder_hidden_states1.append(text_encoder1_cache[input_id1_cache_key])
|
||||
hidden_states2, p2 = text_encoder2_cache[input_id2_cache_key]
|
||||
encoder_hidden_states2.append(hidden_states2)
|
||||
pool2.append(p2)
|
||||
encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)
|
||||
|
||||
Reference in New Issue
Block a user