Mirror of https://github.com/kohya-ss/sd-scripts.git
fix max_token_length not working for sdxl
@@ -126,6 +126,8 @@ def load_tokenizers(args: argparse.Namespace):
 def get_hidden_states(
     args: argparse.Namespace, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
 ):
+    # input_ids: b,n,77 -> b*n, 77
+    b_size = input_ids1.size()[0]
     input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length))  # batch_size*n, 77
     input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length))  # batch_size*n, 77
 
@@ -138,6 +140,11 @@ def get_hidden_states(
     hidden_states2 = enc_out["hidden_states"][-2]  # penultimate layer
     pool2 = enc_out["text_embeds"]
 
+    # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
+    n_size = 1 if args.max_token_length is None else args.max_token_length // 75
+    hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
+    hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
+
     if args.max_token_length is not None:
         # bs*3, 77, 768 or 1024
         # encoder1: merge the three <BOS>...<EOS> chunks back into a single <BOS>...<EOS>
@@ -151,14 +158,19 @@ def get_hidden_states(
         states_list = [hidden_states2[:, 0].unsqueeze(1)]  # <BOS>
         for i in range(1, args.max_token_length, tokenizer2.model_max_length):
             chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2]  # from after <BOS> to before the last token
-            if i > 0:
-                for j in range(len(chunk)):
-                    if input_ids2[j, 1] == tokenizer2.eos_token:  # empty prompt, i.e. the <BOS> <EOS> <PAD> ... pattern
-                        chunk[j, 0] = chunk[j, 1]  # copy the value of the next <PAD>
+            # this causes an error:
+            # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
+            # if i > 1:
+            #     for j in range(len(chunk)):  # batch_size
+            #         if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id:  # empty prompt, i.e. the <BOS> <EOS> <PAD> ... pattern
+            #             chunk[j, 0] = chunk[j, 1]  # copy the value of the next <PAD>
             states_list.append(chunk)  # from after <BOS> to before <EOS>
         states_list.append(hidden_states2[:, -1].unsqueeze(1))  # either <EOS> or <PAD>
         hidden_states2 = torch.cat(states_list, dim=1)
 
+        # for the pooled output, use the first of the n chunks
+        pool2 = pool2[::n_size]
+
     if weight_dtype is not None:
         # this is required for additional network training
         hidden_states1 = hidden_states1.to(weight_dtype)
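For readers following the hunks above, here is a minimal, self-contained sketch (not the repository code) of how the n chunks of 77 hidden states are folded back into one n*75+2 sequence. The dummy tensors and the assumed max_token_length of 225 and hidden size of 1280 are placeholders for illustration only:

import torch

# Illustrative shapes: batch of 2 prompts, max_token_length=225 -> n=3 chunks of 77 tokens.
b_size, max_token_length, model_max_length, hidden_dim = 2, 225, 77, 1280
n_size = max_token_length // 75  # 3

# Pretend encoder output for the b*n flattened chunks: (b*n, 77, d)
hidden_states = torch.randn(b_size * n_size, model_max_length, hidden_dim)
# Pretend pooled output, one vector per chunk: (b*n, d)
pool = torch.randn(b_size * n_size, hidden_dim)

# b*n, 77, d -> b, n*77, d (same reshape as in the diff)
hidden_states = hidden_states.reshape((b_size, -1, hidden_dim))

# Keep one global <BOS> and <EOS> and drop the per-chunk boundary tokens,
# mirroring the states_list construction above.
states_list = [hidden_states[:, 0].unsqueeze(1)]  # first <BOS>
for i in range(1, max_token_length, model_max_length):
    states_list.append(hidden_states[:, i : i + model_max_length - 2])  # 75 content tokens per chunk
states_list.append(hidden_states[:, -1].unsqueeze(1))  # last <EOS> (or <PAD>)
hidden_states = torch.cat(states_list, dim=1)

# For the pooled embedding, keep only the first chunk of each prompt.
pool = pool[::n_size]

print(hidden_states.shape)  # torch.Size([2, 227, 1280]) -> n*75 + 2 tokens
print(pool.shape)           # torch.Size([2, 1280])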
@@ -313,37 +325,30 @@ def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dat
 
         # split batch to avoid OOM
         # TODO specify batch size by args
-        for input_ids1, input_ids2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)):
+        for input_id1, input_id2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)):
             # remove input_ids already in cache
-            input_ids1 = input_ids1.squeeze(0)
-            input_ids2 = input_ids2.squeeze(0)
-            input_ids1 = [i for i in input_ids1 if i not in text_encoder1_cache]
-            input_ids2 = [i for i in input_ids2 if i not in text_encoder2_cache]
-            assert len(input_ids1) == len(input_ids2)
-            if len(input_ids1) == 0:
+            input_id1_cache_key = tuple(input_id1.flatten().tolist())
+            input_id2_cache_key = tuple(input_id2.flatten().tolist())
+            if input_id1_cache_key in text_encoder1_cache:
+                assert input_id2_cache_key in text_encoder2_cache
                 continue
-            input_ids1 = torch.stack(input_ids1).to(accelerator.device)
-            input_ids2 = torch.stack(input_ids2).to(accelerator.device)
 
             with torch.no_grad():
                 encoder_hidden_states1, encoder_hidden_states2, pool2 = get_hidden_states(
                     args,
-                    input_ids1,
-                    input_ids2,
+                    input_id1,
+                    input_id2,
                     tokenizer1,
                     tokenizer2,
                     text_encoder1,
                     text_encoder2,
                     None if not args.full_fp16 else weight_dtype,
                 )
-            encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu")
-            encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu")
-            pool2 = pool2.to("cpu")
-            for input_id1, input_id2, hidden_states1, hidden_states2, p2 in zip(
-                input_ids1, input_ids2, encoder_hidden_states1, encoder_hidden_states2, pool2
-            ):
-                text_encoder1_cache[tuple(input_id1.tolist())] = hidden_states1
-                text_encoder2_cache[tuple(input_id2.tolist())] = (hidden_states2, p2)
+            encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu").squeeze(0)  # n*75+2,768
+            encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu").squeeze(0)  # n*75+2,1280
+            pool2 = pool2.detach().to("cpu").squeeze(0)  # 1280
+            text_encoder1_cache[input_id1_cache_key] = encoder_hidden_states1
+            text_encoder2_cache[input_id2_cache_key] = (encoder_hidden_states2, pool2)
 
     return text_encoder1_cache, text_encoder2_cache
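A hedged sketch of the caching pattern the rewritten loop uses: a plain dict keyed by the tuple of the flattened token IDs (tensors are not hashable, tuples of ints are), holding per-prompt CPU tensors. encode_fn, the random tensors, and the shapes below are illustrative placeholders, not the repository's API:

import torch

# encode_fn stands in for the real text encoder call; here it just returns dummy tensors.
def encode_fn(input_id1, input_id2):
    return torch.randn(1, 227, 768), torch.randn(1, 227, 1280), torch.randn(1, 1280)

text_encoder1_cache = {}
text_encoder2_cache = {}

# One (1, n, 77) slice per prompt, e.g. from tensor.split(1) on a batch.
input_ids1_batch = torch.randint(0, 49408, (4, 3, 77))
input_ids2_batch = torch.randint(0, 49408, (4, 3, 77))

for input_id1, input_id2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)):
    key1 = tuple(input_id1.flatten().tolist())
    key2 = tuple(input_id2.flatten().tolist())
    if key1 in text_encoder1_cache:  # already encoded: skip the forward pass
        assert key2 in text_encoder2_cache
        continue

    with torch.no_grad():
        hs1, hs2, pool2 = encode_fn(input_id1, input_id2)

    # store per-prompt tensors on CPU, dropping the leading batch dim of 1
    text_encoder1_cache[key1] = hs1.detach().to("cpu").squeeze(0)
    text_encoder2_cache[key2] = (hs2.detach().to("cpu").squeeze(0), pool2.detach().to("cpu").squeeze(0))

print(len(text_encoder1_cache), len(text_encoder2_cache))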
@@ -3358,6 +3358,7 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
     if input_ids.size()[-1] != tokenizer.model_max_length:
         return text_encoder(input_ids)[0]
 
+    # input_ids: b,n,77
     b_size = input_ids.size()[0]
     input_ids = input_ids.reshape((-1, tokenizer.model_max_length))  # batch_size*3, 77
 
@@ -382,10 +382,10 @@ def train(args):
                     encoder_hidden_states2 = []
                     pool2 = []
                     for input_id1, input_id2 in zip(input_ids1, input_ids2):
-                        input_id1 = input_id1.squeeze(0)
-                        input_id2 = input_id2.squeeze(0)
-                        encoder_hidden_states1.append(text_encoder1_cache[tuple(input_id1.tolist())])
-                        hidden_states2, p2 = text_encoder2_cache[tuple(input_id2.tolist())]
+                        input_id1_cache_key = tuple(input_id1.squeeze(0).flatten().tolist())
+                        input_id2_cache_key = tuple(input_id2.squeeze(0).flatten().tolist())
+                        encoder_hidden_states1.append(text_encoder1_cache[input_id1_cache_key])
+                        hidden_states2, p2 = text_encoder2_cache[input_id2_cache_key]
                         encoder_hidden_states2.append(hidden_states2)
                         pool2.append(p2)
                     encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)
@@ -11,7 +11,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
         self.sampling_warning_showed = False
 
     def assert_extra_args(self, args, train_dataset_group):
-        super().assert_extra_args(args)
+        super().assert_extra_args(args, train_dataset_group)
         sdxl_train_util.verify_sdxl_training_args(args)
 
         if args.cache_text_encoder_outputs:
@@ -119,10 +119,10 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
             encoder_hidden_states2 = []
             pool2 = []
             for input_id1, input_id2 in zip(input_ids1, input_ids2):
-                input_id1 = input_id1.squeeze(0)
-                input_id2 = input_id2.squeeze(0)
-                encoder_hidden_states1.append(self.text_encoder1_cache[tuple(input_id1.tolist())])
-                hidden_states2, p2 = self.text_encoder2_cache[tuple(input_id2.tolist())]
+                input_id1_cache_key = tuple(input_id1.flatten().tolist())
+                input_id2_cache_key = tuple(input_id2.flatten().tolist())
+                encoder_hidden_states1.append(self.text_encoder1_cache[input_id1_cache_key])
+                hidden_states2, p2 = self.text_encoder2_cache[input_id2_cache_key]
                 encoder_hidden_states2.append(hidden_states2)
                 pool2.append(p2)
             encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)
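And the lookup side, as in the two cache-reading hunks above: an illustrative helper (gather_from_cache is a hypothetical name, not from the repository) that recomputes the same tuple keys and stacks the cached per-prompt tensors back into batched tensors on the training device:

import torch

def gather_from_cache(input_ids1, input_ids2, text_encoder1_cache, text_encoder2_cache, device, dtype):
    # caches are assumed to hold CPU tensors keyed by tuple(input_id.flatten().tolist()),
    # exactly as written during caching
    encoder_hidden_states1, encoder_hidden_states2, pool2 = [], [], []
    for input_id1, input_id2 in zip(input_ids1, input_ids2):
        key1 = tuple(input_id1.flatten().tolist())
        key2 = tuple(input_id2.flatten().tolist())
        encoder_hidden_states1.append(text_encoder1_cache[key1])
        hidden_states2, p2 = text_encoder2_cache[key2]
        encoder_hidden_states2.append(hidden_states2)
        pool2.append(p2)
    # stack per-prompt tensors back into a batch and move them to the training device/dtype
    encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(device).to(dtype)
    encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(device).to(dtype)
    pool2 = torch.stack(pool2).to(device).to(dtype)
    return encoder_hidden_states1, encoder_hidden_states2, pool2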
@@ -83,7 +83,7 @@ class NetworkTrainer:
 
         return logs
 
-    def assert_extra_args(self, args):
+    def assert_extra_args(self, args, train_dataset_group):
         pass
 
     def load_target_model(self, args, weight_dtype, accelerator):