From 663b481029014c8fcdffb0a34f4d63130f2cf6a6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 3 Jan 2024 23:22:00 +0900 Subject: [PATCH] fix TI training with full_fp16/bf16 ref #1019 --- train_textual_inversion.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index f3d71945..545b6ba8 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -616,9 +616,11 @@ class TextualInversionTrainer: for text_encoder, orig_embeds_params, index_no_updates in zip( text_encoders, orig_embeds_params_list, index_no_updates_list ): - accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ + # if full_fp16/bf16, input_embeddings_weight is fp16/bf16, orig_embeds_params is fp32 + input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight + input_embeddings_weight[index_no_updates] = orig_embeds_params.to(input_embeddings_weight.dtype)[ index_no_updates - ] = orig_embeds_params[index_no_updates] + ] # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: