From 5e96e1369da410e52725e2c7af8b9f9def956c4b Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 8 Feb 2023 19:23:39 +0800 Subject: [PATCH] fix get_hidden_states expected scalar Error --- library/train_util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index dc0724d7..4c410567 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1560,10 +1560,6 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod else: enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - # uncomment code may raise expected scalar type Half but found Float when using DDP - # if weight_dtype is not None: - # # this is required for additional network training - # encoder_hidden_states = encoder_hidden_states.to(weight_dtype) encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) # bs*3, 77, 768 or 1024 @@ -1589,6 +1585,10 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # encoder_hidden_states = torch.cat(states_list, dim=1) + + if weight_dtype is not None: + # this is required for additional network training + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) return encoder_hidden_states