mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
fix TI training with full_fp16/bf16 ref #1019
This commit is contained in:
@@ -616,9 +616,11 @@ class TextualInversionTrainer:
|
||||
for text_encoder, orig_embeds_params, index_no_updates in zip(
|
||||
text_encoders, orig_embeds_params_list, index_no_updates_list
|
||||
):
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||
# if full_fp16/bf16, input_embeddings_weight is fp16/bf16, orig_embeds_params is fp32
|
||||
input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
|
||||
input_embeddings_weight[index_no_updates] = orig_embeds_params.to(input_embeddings_weight.dtype)[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
Reference in New Issue
Block a user