mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix max_token_length not works for sdxl
This commit is contained in:
@@ -126,6 +126,8 @@ def load_tokenizers(args: argparse.Namespace):
|
||||
def get_hidden_states(
|
||||
args: argparse.Namespace, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
|
||||
):
|
||||
# input_ids: b,n,77 -> b*n, 77
|
||||
b_size = input_ids1.size()[0]
|
||||
input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
|
||||
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
|
||||
|
||||
@@ -138,6 +140,11 @@ def get_hidden_states(
|
||||
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
|
||||
pool2 = enc_out["text_embeds"]
|
||||
|
||||
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
|
||||
n_size = 1 if args.max_token_length is None else args.max_token_length // 75
|
||||
hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
|
||||
hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
|
||||
|
||||
if args.max_token_length is not None:
|
||||
# bs*3, 77, 768 or 1024
|
||||
# encoder1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
|
||||
@@ -151,14 +158,19 @@ def get_hidden_states(
|
||||
states_list = [hidden_states2[:, 0].unsqueeze(1)] # <BOS>
|
||||
for i in range(1, args.max_token_length, tokenizer2.model_max_length):
|
||||
chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # <BOS> の後から 最後の前まで
|
||||
if i > 0:
|
||||
for j in range(len(chunk)):
|
||||
if input_ids2[j, 1] == tokenizer2.eos_token: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
|
||||
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
|
||||
# this causes an error:
|
||||
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
|
||||
# if i > 1:
|
||||
# for j in range(len(chunk)): # batch_size
|
||||
# if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
|
||||
# chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
|
||||
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
|
||||
states_list.append(hidden_states2[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
|
||||
hidden_states2 = torch.cat(states_list, dim=1)
|
||||
|
||||
# pool はnの最初のものを使う
|
||||
pool2 = pool2[::n_size]
|
||||
|
||||
if weight_dtype is not None:
|
||||
# this is required for additional network training
|
||||
hidden_states1 = hidden_states1.to(weight_dtype)
|
||||
@@ -313,37 +325,30 @@ def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dat
|
||||
|
||||
# split batch to avoid OOM
|
||||
# TODO specify batch size by args
|
||||
for input_ids1, input_ids2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)):
|
||||
for input_id1, input_id2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)):
|
||||
# remove input_ids already in cache
|
||||
input_ids1 = input_ids1.squeeze(0)
|
||||
input_ids2 = input_ids2.squeeze(0)
|
||||
input_ids1 = [i for i in input_ids1 if i not in text_encoder1_cache]
|
||||
input_ids2 = [i for i in input_ids2 if i not in text_encoder2_cache]
|
||||
assert len(input_ids1) == len(input_ids2)
|
||||
if len(input_ids1) == 0:
|
||||
input_id1_cache_key = tuple(input_id1.flatten().tolist())
|
||||
input_id2_cache_key = tuple(input_id2.flatten().tolist())
|
||||
if input_id1_cache_key in text_encoder1_cache:
|
||||
assert input_id2_cache_key in text_encoder2_cache
|
||||
continue
|
||||
input_ids1 = torch.stack(input_ids1).to(accelerator.device)
|
||||
input_ids2 = torch.stack(input_ids2).to(accelerator.device)
|
||||
|
||||
with torch.no_grad():
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = get_hidden_states(
|
||||
args,
|
||||
input_ids1,
|
||||
input_ids2,
|
||||
input_id1,
|
||||
input_id2,
|
||||
tokenizer1,
|
||||
tokenizer2,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
None if not args.full_fp16 else weight_dtype,
|
||||
)
|
||||
encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu")
|
||||
encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu")
|
||||
pool2 = pool2.to("cpu")
|
||||
for input_id1, input_id2, hidden_states1, hidden_states2, p2 in zip(
|
||||
input_ids1, input_ids2, encoder_hidden_states1, encoder_hidden_states2, pool2
|
||||
):
|
||||
text_encoder1_cache[tuple(input_id1.tolist())] = hidden_states1
|
||||
text_encoder2_cache[tuple(input_id2.tolist())] = (hidden_states2, p2)
|
||||
encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu").squeeze(0) # n*75+2,768
|
||||
encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu").squeeze(0) # n*75+2,1280
|
||||
pool2 = pool2.detach().to("cpu").squeeze(0) # 1280
|
||||
text_encoder1_cache[input_id1_cache_key] = encoder_hidden_states1
|
||||
text_encoder2_cache[input_id2_cache_key] = (encoder_hidden_states2, pool2)
|
||||
return text_encoder1_cache, text_encoder2_cache
|
||||
|
||||
|
||||
|
||||
@@ -3358,6 +3358,7 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
|
||||
if input_ids.size()[-1] != tokenizer.model_max_length:
|
||||
return text_encoder(input_ids)[0]
|
||||
|
||||
# input_ids: b,n,77
|
||||
b_size = input_ids.size()[0]
|
||||
input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
|
||||
|
||||
|
||||
@@ -382,10 +382,10 @@ def train(args):
|
||||
encoder_hidden_states2 = []
|
||||
pool2 = []
|
||||
for input_id1, input_id2 in zip(input_ids1, input_ids2):
|
||||
input_id1 = input_id1.squeeze(0)
|
||||
input_id2 = input_id2.squeeze(0)
|
||||
encoder_hidden_states1.append(text_encoder1_cache[tuple(input_id1.tolist())])
|
||||
hidden_states2, p2 = text_encoder2_cache[tuple(input_id2.tolist())]
|
||||
input_id1_cache_key = tuple(input_id1.squeeze(0).flatten().tolist())
|
||||
input_id2_cache_key = tuple(input_id2.squeeze(0).flatten().tolist())
|
||||
encoder_hidden_states1.append(text_encoder1_cache[input_id1_cache_key])
|
||||
hidden_states2, p2 = text_encoder2_cache[input_id2_cache_key]
|
||||
encoder_hidden_states2.append(hidden_states2)
|
||||
pool2.append(p2)
|
||||
encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)
|
||||
|
||||
@@ -11,7 +11,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.sampling_warning_showed = False
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
super().assert_extra_args(args)
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
@@ -119,10 +119,10 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
encoder_hidden_states2 = []
|
||||
pool2 = []
|
||||
for input_id1, input_id2 in zip(input_ids1, input_ids2):
|
||||
input_id1 = input_id1.squeeze(0)
|
||||
input_id2 = input_id2.squeeze(0)
|
||||
encoder_hidden_states1.append(self.text_encoder1_cache[tuple(input_id1.tolist())])
|
||||
hidden_states2, p2 = self.text_encoder2_cache[tuple(input_id2.tolist())]
|
||||
input_id1_cache_key = tuple(input_id1.flatten().tolist())
|
||||
input_id2_cache_key = tuple(input_id2.flatten().tolist())
|
||||
encoder_hidden_states1.append(self.text_encoder1_cache[input_id1_cache_key])
|
||||
hidden_states2, p2 = self.text_encoder2_cache[input_id2_cache_key]
|
||||
encoder_hidden_states2.append(hidden_states2)
|
||||
pool2.append(p2)
|
||||
encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)
|
||||
|
||||
@@ -83,7 +83,7 @@ class NetworkTrainer:
|
||||
|
||||
return logs
|
||||
|
||||
def assert_extra_args(self, args):
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
pass
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
|
||||
Reference in New Issue
Block a user