From c0be52a7731b58e305ae22ed26e57af6b7d61f5a Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 8 Feb 2023 19:05:39 +0800 Subject: [PATCH] ignore get_hidden_states expected scalar Error --- library/train_util.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 6f809deb..dc0724d7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1560,9 +1560,10 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod else: enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - if weight_dtype is not None: - # this is required for additional network training - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + # uncomment code may raise expected scalar type Half but found Float when using DDP + # if weight_dtype is not None: + # # this is required for additional network training + # encoder_hidden_states = encoder_hidden_states.to(weight_dtype) encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) # bs*3, 77, 768 or 1024