fix max_token_length not working for sdxl

This commit is contained in:
Kohya S
2023-06-29 13:02:19 +09:00
parent 8521ab7990
commit d395bc0647
5 changed files with 39 additions and 33 deletions

View File

@@ -126,6 +126,8 @@ def load_tokenizers(args: argparse.Namespace):
def get_hidden_states(
args: argparse.Namespace, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
):
# input_ids: b,n,77 -> b*n, 77
b_size = input_ids1.size()[0]
input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
@@ -138,6 +140,11 @@ def get_hidden_states(
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
pool2 = enc_out["text_embeds"]
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
n_size = 1 if args.max_token_length is None else args.max_token_length // 75
hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
if args.max_token_length is not None:
# bs*3, 77, 768 or 1024
# encoder1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
@@ -151,14 +158,19 @@ def get_hidden_states(
states_list = [hidden_states2[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, args.max_token_length, tokenizer2.model_max_length):
chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # <BOS> の後から 最後の前まで
if i > 0:
for j in range(len(chunk)):
if input_ids2[j, 1] == tokenizer2.eos_token: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
# this causes an error:
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# if i > 1:
# for j in range(len(chunk)): # batch_size
# if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
# chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
states_list.append(hidden_states2[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
hidden_states2 = torch.cat(states_list, dim=1)
# pool はnの最初のものを使う
pool2 = pool2[::n_size]
if weight_dtype is not None:
# this is required for additional network training
hidden_states1 = hidden_states1.to(weight_dtype)
@@ -313,37 +325,30 @@ def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dat
# split batch to avoid OOM
# TODO specify batch size by args
for input_ids1, input_ids2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)):
for input_id1, input_id2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)):
# remove input_ids already in cache
input_ids1 = input_ids1.squeeze(0)
input_ids2 = input_ids2.squeeze(0)
input_ids1 = [i for i in input_ids1 if i not in text_encoder1_cache]
input_ids2 = [i for i in input_ids2 if i not in text_encoder2_cache]
assert len(input_ids1) == len(input_ids2)
if len(input_ids1) == 0:
input_id1_cache_key = tuple(input_id1.flatten().tolist())
input_id2_cache_key = tuple(input_id2.flatten().tolist())
if input_id1_cache_key in text_encoder1_cache:
assert input_id2_cache_key in text_encoder2_cache
continue
input_ids1 = torch.stack(input_ids1).to(accelerator.device)
input_ids2 = torch.stack(input_ids2).to(accelerator.device)
with torch.no_grad():
encoder_hidden_states1, encoder_hidden_states2, pool2 = get_hidden_states(
args,
input_ids1,
input_ids2,
input_id1,
input_id2,
tokenizer1,
tokenizer2,
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
)
encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu")
encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu")
pool2 = pool2.to("cpu")
for input_id1, input_id2, hidden_states1, hidden_states2, p2 in zip(
input_ids1, input_ids2, encoder_hidden_states1, encoder_hidden_states2, pool2
):
text_encoder1_cache[tuple(input_id1.tolist())] = hidden_states1
text_encoder2_cache[tuple(input_id2.tolist())] = (hidden_states2, p2)
encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu").squeeze(0) # n*75+2,768
encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu").squeeze(0) # n*75+2,1280
pool2 = pool2.detach().to("cpu").squeeze(0) # 1280
text_encoder1_cache[input_id1_cache_key] = encoder_hidden_states1
text_encoder2_cache[input_id2_cache_key] = (encoder_hidden_states2, pool2)
return text_encoder1_cache, text_encoder2_cache

View File

@@ -3358,6 +3358,7 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
if input_ids.size()[-1] != tokenizer.model_max_length:
return text_encoder(input_ids)[0]
# input_ids: b,n,77
b_size = input_ids.size()[0]
input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77