diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 46756b86..e3b502c2 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -126,6 +126,8 @@ def load_tokenizers(args: argparse.Namespace): def get_hidden_states( args: argparse.Namespace, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None ): + # input_ids: b,n,77 -> b*n, 77 + b_size = input_ids1.size()[0] input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77 input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77 @@ -138,6 +140,11 @@ def get_hidden_states( hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer pool2 = enc_out["text_embeds"] + # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280 + n_size = 1 if args.max_token_length is None else args.max_token_length // 75 + hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1])) + hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1])) + if args.max_token_length is not None: # bs*3, 77, 768 or 1024 # encoder1: ... の三連を ... へ戻す @@ -151,14 +158,19 @@ def get_hidden_states( states_list = [hidden_states2[:, 0].unsqueeze(1)] # for i in range(1, args.max_token_length, tokenizer2.model_max_length): chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # の後から 最後の前まで - if i > 0: - for j in range(len(chunk)): - if input_ids2[j, 1] == tokenizer2.eos_token: # 空、つまり ...のパターン - chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + # this causes an error: + # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation + # if i > 1: + # for j in range(len(chunk)): # batch_size + # if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり ...のパターン + # chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする states_list.append(chunk) # の後から の前まで states_list.append(hidden_states2[:, -1].unsqueeze(1)) # のどちらか hidden_states2 = torch.cat(states_list, dim=1) + # pool はnの最初のものを使う + pool2 = pool2[::n_size] + if weight_dtype is not None: # this is required for additional network training hidden_states1 = hidden_states1.to(weight_dtype) @@ -313,37 +325,30 @@ def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dat # split batch to avoid OOM # TODO specify batch size by args - for input_ids1, input_ids2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)): + for input_id1, input_id2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)): # remove input_ids already in cache - input_ids1 = input_ids1.squeeze(0) - input_ids2 = input_ids2.squeeze(0) - input_ids1 = [i for i in input_ids1 if i not in text_encoder1_cache] - input_ids2 = [i for i in input_ids2 if i not in text_encoder2_cache] - assert len(input_ids1) == len(input_ids2) - if len(input_ids1) == 0: + input_id1_cache_key = tuple(input_id1.flatten().tolist()) + input_id2_cache_key = tuple(input_id2.flatten().tolist()) + if input_id1_cache_key in text_encoder1_cache: + assert input_id2_cache_key in text_encoder2_cache continue - input_ids1 = torch.stack(input_ids1).to(accelerator.device) - input_ids2 = torch.stack(input_ids2).to(accelerator.device) with torch.no_grad(): encoder_hidden_states1, encoder_hidden_states2, pool2 = get_hidden_states( args, - input_ids1, - input_ids2, + input_id1, + input_id2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, None if not args.full_fp16 else weight_dtype, ) - encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu") - encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu") - pool2 = pool2.to("cpu") - for input_id1, input_id2, hidden_states1, hidden_states2, p2 in zip( - input_ids1, input_ids2, encoder_hidden_states1, encoder_hidden_states2, pool2 - ): - text_encoder1_cache[tuple(input_id1.tolist())] = hidden_states1 - text_encoder2_cache[tuple(input_id2.tolist())] = (hidden_states2, p2) + encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu").squeeze(0) # n*75+2,768 + encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu").squeeze(0) # n*75+2,1280 + pool2 = pool2.detach().to("cpu").squeeze(0) # 1280 + text_encoder1_cache[input_id1_cache_key] = encoder_hidden_states1 + text_encoder2_cache[input_id2_cache_key] = (encoder_hidden_states2, pool2) return text_encoder1_cache, text_encoder2_cache diff --git a/library/train_util.py b/library/train_util.py index c27e1b27..43f55353 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3358,6 +3358,7 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod if input_ids.size()[-1] != tokenizer.model_max_length: return text_encoder(input_ids)[0] + # input_ids: b,n,77 b_size = input_ids.size()[0] input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 diff --git a/sdxl_train.py b/sdxl_train.py index f640580a..66b3c76d 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -382,10 +382,10 @@ def train(args): encoder_hidden_states2 = [] pool2 = [] for input_id1, input_id2 in zip(input_ids1, input_ids2): - input_id1 = input_id1.squeeze(0) - input_id2 = input_id2.squeeze(0) - encoder_hidden_states1.append(text_encoder1_cache[tuple(input_id1.tolist())]) - hidden_states2, p2 = text_encoder2_cache[tuple(input_id2.tolist())] + input_id1_cache_key = tuple(input_id1.squeeze(0).flatten().tolist()) + input_id2_cache_key = tuple(input_id2.squeeze(0).flatten().tolist()) + encoder_hidden_states1.append(text_encoder1_cache[input_id1_cache_key]) + hidden_states2, p2 = text_encoder2_cache[input_id2_cache_key] encoder_hidden_states2.append(hidden_states2) pool2.append(p2) encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index fb445fb7..8f52fe5d 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -11,7 +11,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): self.sampling_warning_showed = False def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args) + super().assert_extra_args(args, train_dataset_group) sdxl_train_util.verify_sdxl_training_args(args) if args.cache_text_encoder_outputs: @@ -119,10 +119,10 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): encoder_hidden_states2 = [] pool2 = [] for input_id1, input_id2 in zip(input_ids1, input_ids2): - input_id1 = input_id1.squeeze(0) - input_id2 = input_id2.squeeze(0) - encoder_hidden_states1.append(self.text_encoder1_cache[tuple(input_id1.tolist())]) - hidden_states2, p2 = self.text_encoder2_cache[tuple(input_id2.tolist())] + input_id1_cache_key = tuple(input_id1.flatten().tolist()) + input_id2_cache_key = tuple(input_id2.flatten().tolist()) + encoder_hidden_states1.append(self.text_encoder1_cache[input_id1_cache_key]) + hidden_states2, p2 = self.text_encoder2_cache[input_id2_cache_key] encoder_hidden_states2.append(hidden_states2) pool2.append(p2) encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype) diff --git a/train_network.py b/train_network.py index 1a3c19e2..e42225f1 100644 --- a/train_network.py +++ b/train_network.py @@ -83,7 +83,7 @@ class NetworkTrainer: return logs - def assert_extra_args(self, args): + def assert_extra_args(self, args, train_dataset_group): pass def load_target_model(self, args, weight_dtype, accelerator):