mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix: max_token_length does not work for SDXL
This commit is contained in:
@@ -11,7 +11,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.sampling_warning_showed = False
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
super().assert_extra_args(args)
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
@@ -119,10 +119,10 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
encoder_hidden_states2 = []
|
||||
pool2 = []
|
||||
for input_id1, input_id2 in zip(input_ids1, input_ids2):
|
||||
input_id1 = input_id1.squeeze(0)
|
||||
input_id2 = input_id2.squeeze(0)
|
||||
encoder_hidden_states1.append(self.text_encoder1_cache[tuple(input_id1.tolist())])
|
||||
hidden_states2, p2 = self.text_encoder2_cache[tuple(input_id2.tolist())]
|
||||
input_id1_cache_key = tuple(input_id1.flatten().tolist())
|
||||
input_id2_cache_key = tuple(input_id2.flatten().tolist())
|
||||
encoder_hidden_states1.append(self.text_encoder1_cache[input_id1_cache_key])
|
||||
hidden_states2, p2 = self.text_encoder2_cache[input_id2_cache_key]
|
||||
encoder_hidden_states2.append(hidden_states2)
|
||||
pool2.append(p2)
|
||||
encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)
|
||||
|
||||
Reference in New Issue
Block a user