diff --git a/train_textual_inversion.py b/train_textual_inversion.py index f3d71945..545b6ba8 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -616,9 +616,11 @@ class TextualInversionTrainer: for text_encoder, orig_embeds_params, index_no_updates in zip( text_encoders, orig_embeds_params_list, index_no_updates_list ): - accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ + # if full_fp16/bf16, input_embeddings_weight is fp16/bf16, orig_embeds_params is fp32 + input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight + input_embeddings_weight[index_no_updates] = orig_embeds_params.to(input_embeddings_weight.dtype)[ index_no_updates - ] = orig_embeds_params[index_no_updates] + ] # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: