fix TI training with full_fp16/bf16 ref #1019

This commit is contained in:
Kohya S
2024-01-03 23:22:00 +09:00
parent 1ab6493268
commit 663b481029

View File

@@ -616,9 +616,11 @@ class TextualInversionTrainer:
for text_encoder, orig_embeds_params, index_no_updates in zip(
text_encoders, orig_embeds_params_list, index_no_updates_list
):
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
# if full_fp16/bf16, input_embeddings_weight is fp16/bf16, orig_embeds_params is fp32
input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
input_embeddings_weight[index_no_updates] = orig_embeds_params.to(input_embeddings_weight.dtype)[
index_no_updates
] = orig_embeds_params[index_no_updates]
]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: