fix max_token_length not works for sdxl

This commit is contained in:
Kohya S
2023-06-29 13:02:19 +09:00
parent 8521ab7990
commit d395bc0647
5 changed files with 39 additions and 33 deletions

View File

@@ -11,7 +11,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
self.sampling_warning_showed = False
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args)
super().assert_extra_args(args, train_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args)
if args.cache_text_encoder_outputs:
@@ -119,10 +119,10 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
encoder_hidden_states2 = []
pool2 = []
for input_id1, input_id2 in zip(input_ids1, input_ids2):
input_id1 = input_id1.squeeze(0)
input_id2 = input_id2.squeeze(0)
encoder_hidden_states1.append(self.text_encoder1_cache[tuple(input_id1.tolist())])
hidden_states2, p2 = self.text_encoder2_cache[tuple(input_id2.tolist())]
input_id1_cache_key = tuple(input_id1.flatten().tolist())
input_id2_cache_key = tuple(input_id2.flatten().tolist())
encoder_hidden_states1.append(self.text_encoder1_cache[input_id1_cache_key])
hidden_states2, p2 = self.text_encoder2_cache[input_id2_cache_key]
encoder_hidden_states2.append(hidden_states2)
pool2.append(p2)
encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)