fix get_hidden_states expected scalar Error

This commit is contained in:
Isotr0py
2023-02-08 19:23:39 +08:00
parent c0be52a773
commit 5e96e1369d

View File

@@ -1560,10 +1560,6 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
else:
enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
# uncomment code may raise expected scalar type Half but found Float when using DDP
# if weight_dtype is not None:
# # this is required for additional network training
# encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
# bs*3, 77, 768 or 1024
@@ -1590,6 +1586,10 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)
if weight_dtype is not None:
# this is required for additional network training
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
return encoder_hidden_states