Merge pull request #1452 from fireicewolf/sd3-devel

Fix AttributeError: 'T5EncoderModel' object has no attribute 'text_model', while loading T5 model in GPU.
This commit is contained in:
Kohya S.
2024-08-15 21:12:19 +09:00
committed by GitHub

View File

@@ -543,9 +543,13 @@ class NetworkTrainer:
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
if t_enc.device.type != "cpu": if t_enc.device.type != "cpu":
t_enc.to(dtype=te_weight_dtype) t_enc.to(dtype=te_weight_dtype)
if hasattr(t_enc.text_model, "embeddings"): if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"):
# nn.Embedding not support FP8 # nn.Embedding not support FP8
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) t_enc.text_model.embeddings.to(
dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"):
t_enc.encoder.embeddings.to(
dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if args.deepspeed: if args.deepspeed: