fix text encodes are on gpu even when not trained

This commit is contained in:
Kohya S
2024-01-17 21:31:50 +09:00
parent dcf0eeb5b6
commit 976d092c68
2 changed files with 8 additions and 8 deletions

View File

@@ -95,8 +95,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
unet.to(org_unet_device)
else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
text_encoders[0].to(accelerator.device)
text_encoders[1].to(accelerator.device)
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: