From 9760d097b0bd7efbeb065d4320b2216a94e76efd Mon Sep 17 00:00:00 2001 From: DukeG Date: Wed, 14 Aug 2024 19:58:54 +0800 Subject: [PATCH] Fix AttributeError: 'T5EncoderModel' object has no attribute 'text_model' While loading T5 model in GPU. --- train_network.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index 367203f5..405aa747 100644 --- a/train_network.py +++ b/train_network.py @@ -540,9 +540,13 @@ class NetworkTrainer: # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 if t_enc.device.type != "cpu": t_enc.to(dtype=te_weight_dtype) - if hasattr(t_enc.text_model, "embeddings"): + if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"): # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + t_enc.text_model.embeddings.to( + dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"): + t_enc.encoder.embeddings.to( + dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: