mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix get_hidden_states expected scalar Error
This commit is contained in:
@@ -1560,10 +1560,6 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
|
|||||||
else:
|
else:
|
||||||
enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
|
enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
|
||||||
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
||||||
# uncomment code may raise expected scalar type Half but found Float when using DDP
|
|
||||||
# if weight_dtype is not None:
|
|
||||||
# # this is required for additional network training
|
|
||||||
# encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
|
|
||||||
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
||||||
|
|
||||||
# bs*3, 77, 768 or 1024
|
# bs*3, 77, 768 or 1024
|
||||||
@@ -1589,6 +1585,10 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
|
|||||||
states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
|
states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
|
||||||
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
|
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
|
||||||
encoder_hidden_states = torch.cat(states_list, dim=1)
|
encoder_hidden_states = torch.cat(states_list, dim=1)
|
||||||
|
|
||||||
|
if weight_dtype is not None:
|
||||||
|
# this is required for additional network training
|
||||||
|
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
|
||||||
|
|
||||||
return encoder_hidden_states
|
return encoder_hidden_states
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user