fix max_token_length not works for sdxl

This commit is contained in:
Kohya S
2023-06-29 13:02:19 +09:00
parent 8521ab7990
commit d395bc0647
5 changed files with 39 additions and 33 deletions

View File

@@ -126,6 +126,8 @@ def load_tokenizers(args: argparse.Namespace):
def get_hidden_states( def get_hidden_states(
args: argparse.Namespace, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None args: argparse.Namespace, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
): ):
# input_ids: b,n,77 -> b*n, 77
b_size = input_ids1.size()[0]
input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77 input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77 input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
@@ -138,6 +140,11 @@ def get_hidden_states(
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
pool2 = enc_out["text_embeds"] pool2 = enc_out["text_embeds"]
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
n_size = 1 if args.max_token_length is None else args.max_token_length // 75
hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
if args.max_token_length is not None: if args.max_token_length is not None:
# bs*3, 77, 768 or 1024 # bs*3, 77, 768 or 1024
# encoder1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す # encoder1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
@@ -151,14 +158,19 @@ def get_hidden_states(
states_list = [hidden_states2[:, 0].unsqueeze(1)] # <BOS> states_list = [hidden_states2[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, args.max_token_length, tokenizer2.model_max_length): for i in range(1, args.max_token_length, tokenizer2.model_max_length):
chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # <BOS> の後から 最後の前まで chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # <BOS> の後から 最後の前まで
if i > 0: # this causes an error:
for j in range(len(chunk)): # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
if input_ids2[j, 1] == tokenizer2.eos_token: # 空、つまり <BOS> <EOS> <PAD> ...のパターン # if i > 1:
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする # for j in range(len(chunk)): # batch_size
# if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
# chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
states_list.append(chunk) # <BOS> の後から <EOS> の前まで states_list.append(chunk) # <BOS> の後から <EOS> の前まで
states_list.append(hidden_states2[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか states_list.append(hidden_states2[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
hidden_states2 = torch.cat(states_list, dim=1) hidden_states2 = torch.cat(states_list, dim=1)
# pool はnの最初のものを使う
pool2 = pool2[::n_size]
if weight_dtype is not None: if weight_dtype is not None:
# this is required for additional network training # this is required for additional network training
hidden_states1 = hidden_states1.to(weight_dtype) hidden_states1 = hidden_states1.to(weight_dtype)
@@ -313,37 +325,30 @@ def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dat
# split batch to avoid OOM # split batch to avoid OOM
# TODO specify batch size by args # TODO specify batch size by args
for input_ids1, input_ids2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)): for input_id1, input_id2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)):
# remove input_ids already in cache # remove input_ids already in cache
input_ids1 = input_ids1.squeeze(0) input_id1_cache_key = tuple(input_id1.flatten().tolist())
input_ids2 = input_ids2.squeeze(0) input_id2_cache_key = tuple(input_id2.flatten().tolist())
input_ids1 = [i for i in input_ids1 if i not in text_encoder1_cache] if input_id1_cache_key in text_encoder1_cache:
input_ids2 = [i for i in input_ids2 if i not in text_encoder2_cache] assert input_id2_cache_key in text_encoder2_cache
assert len(input_ids1) == len(input_ids2)
if len(input_ids1) == 0:
continue continue
input_ids1 = torch.stack(input_ids1).to(accelerator.device)
input_ids2 = torch.stack(input_ids2).to(accelerator.device)
with torch.no_grad(): with torch.no_grad():
encoder_hidden_states1, encoder_hidden_states2, pool2 = get_hidden_states( encoder_hidden_states1, encoder_hidden_states2, pool2 = get_hidden_states(
args, args,
input_ids1, input_id1,
input_ids2, input_id2,
tokenizer1, tokenizer1,
tokenizer2, tokenizer2,
text_encoder1, text_encoder1,
text_encoder2, text_encoder2,
None if not args.full_fp16 else weight_dtype, None if not args.full_fp16 else weight_dtype,
) )
encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu") encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu").squeeze(0) # n*75+2,768
encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu") encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu").squeeze(0) # n*75+2,1280
pool2 = pool2.to("cpu") pool2 = pool2.detach().to("cpu").squeeze(0) # 1280
for input_id1, input_id2, hidden_states1, hidden_states2, p2 in zip( text_encoder1_cache[input_id1_cache_key] = encoder_hidden_states1
input_ids1, input_ids2, encoder_hidden_states1, encoder_hidden_states2, pool2 text_encoder2_cache[input_id2_cache_key] = (encoder_hidden_states2, pool2)
):
text_encoder1_cache[tuple(input_id1.tolist())] = hidden_states1
text_encoder2_cache[tuple(input_id2.tolist())] = (hidden_states2, p2)
return text_encoder1_cache, text_encoder2_cache return text_encoder1_cache, text_encoder2_cache

View File

@@ -3358,6 +3358,7 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
if input_ids.size()[-1] != tokenizer.model_max_length: if input_ids.size()[-1] != tokenizer.model_max_length:
return text_encoder(input_ids)[0] return text_encoder(input_ids)[0]
# input_ids: b,n,77
b_size = input_ids.size()[0] b_size = input_ids.size()[0]
input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77

View File

@@ -382,10 +382,10 @@ def train(args):
encoder_hidden_states2 = [] encoder_hidden_states2 = []
pool2 = [] pool2 = []
for input_id1, input_id2 in zip(input_ids1, input_ids2): for input_id1, input_id2 in zip(input_ids1, input_ids2):
input_id1 = input_id1.squeeze(0) input_id1_cache_key = tuple(input_id1.squeeze(0).flatten().tolist())
input_id2 = input_id2.squeeze(0) input_id2_cache_key = tuple(input_id2.squeeze(0).flatten().tolist())
encoder_hidden_states1.append(text_encoder1_cache[tuple(input_id1.tolist())]) encoder_hidden_states1.append(text_encoder1_cache[input_id1_cache_key])
hidden_states2, p2 = text_encoder2_cache[tuple(input_id2.tolist())] hidden_states2, p2 = text_encoder2_cache[input_id2_cache_key]
encoder_hidden_states2.append(hidden_states2) encoder_hidden_states2.append(hidden_states2)
pool2.append(p2) pool2.append(p2)
encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype) encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)

View File

@@ -11,7 +11,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
self.sampling_warning_showed = False self.sampling_warning_showed = False
def assert_extra_args(self, args, train_dataset_group): def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args) super().assert_extra_args(args, train_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args) sdxl_train_util.verify_sdxl_training_args(args)
if args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs:
@@ -119,10 +119,10 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
encoder_hidden_states2 = [] encoder_hidden_states2 = []
pool2 = [] pool2 = []
for input_id1, input_id2 in zip(input_ids1, input_ids2): for input_id1, input_id2 in zip(input_ids1, input_ids2):
input_id1 = input_id1.squeeze(0) input_id1_cache_key = tuple(input_id1.flatten().tolist())
input_id2 = input_id2.squeeze(0) input_id2_cache_key = tuple(input_id2.flatten().tolist())
encoder_hidden_states1.append(self.text_encoder1_cache[tuple(input_id1.tolist())]) encoder_hidden_states1.append(self.text_encoder1_cache[input_id1_cache_key])
hidden_states2, p2 = self.text_encoder2_cache[tuple(input_id2.tolist())] hidden_states2, p2 = self.text_encoder2_cache[input_id2_cache_key]
encoder_hidden_states2.append(hidden_states2) encoder_hidden_states2.append(hidden_states2)
pool2.append(p2) pool2.append(p2)
encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype) encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)

View File

@@ -83,7 +83,7 @@ class NetworkTrainer:
return logs return logs
def assert_extra_args(self, args): def assert_extra_args(self, args, train_dataset_group):
pass pass
def load_target_model(self, args, weight_dtype, accelerator): def load_target_model(self, args, weight_dtype, accelerator):